@@ -9,22 +9,20 @@
 import argparse
 import csv
 import json
-import logging
 import multiprocessing
 import os
 import sys
 import time
 
-from typing import cast, Dict, Iterator, List, Optional
+from typing import Dict, List, Optional
 
 import numpy as np
 
 import torch
 import torch.nn as nn
 
-from line_profiler import LineProfiler
-
 from torch import distributed as dist
+from torch.profiler import profile, ProfilerActivity, record_function
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter  # @manual //caffe2:torch_tensorboard
 from torchrec.metrics.metrics_namespace import MetricPrefix
@@ -38,7 +36,6 @@
 from .benchmark_zch_utils import BenchmarkMCProbe, get_logger, get_module_from_instance
 
 from .data.get_dataloader import get_dataloader
-from .data.get_metric_modules import get_metric_modules
 from .data.nonzch_remapper import NonZchModRemapperModule
 from .models.apply_optimizers import (
    apply_dense_optimizers,
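For reference, the torch.profiler names added above (profile, ProfilerActivity, record_function) replace the dropped line_profiler dependency and are typically combined as in the following minimal, standalone sketch; the Linear model and random input are placeholders, not code from this benchmark:

import torch
from torch.profiler import profile, ProfilerActivity, record_function

model = torch.nn.Linear(128, 64)
inputs = torch.randn(32, 128)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    with record_function("forward_pass"):  # names this region in the trace
        model(inputs)

# per-operator summary, plus a trace viewable in chrome://tracing or Perfetto
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")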
@@ -86,7 +83,6 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
 
     # get metric modules
     logger.info(f"[rank {rank}] get metric modules")
-    metric_modules = get_metric_modules(rank, args, device)
 
     # make the model
     logger.info(f"[rank {rank}] make model")
@@ -146,166 +142,133 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
     total_num_queries_in_training = 0
 
     # train the model
-    logger.info(f"[rank {rank}] train the model")
-    batch_cnt = 0
-    for epoch_idx in range(args.epochs):
-        model.train()
-        starter_list = []
-        ender_list = []
-        num_queries_per_batch_list = []
-        loss_per_batch_list = []
-        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
-        for batch_idx, batch in enumerate(pbar):
-            # batch = batch.to(device)
-            batch = batch.to(device)
-            # remap the batch if needed
-            if len(args.zch_method) == 0:
-                # pyre-ignore [16] # NOTE: pyre reports nonzch_remapper can be None, but when reach to this branch of condition, we know it is not None
-                batch = nonzch_remapper.remap(batch)
-            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
-                enable_timing=True
-            )
-            if True or len(args.zch_method) > 0:
-                benchmark_probe.record_mcec_state(stage="before_fwd")
-            # train model
-            starter.record()
-            ## zero the gradients
-            optimizer.zero_grad()
-            ## forward pass
-            loss, (loss_values, pred_logits, labels, weights) = model(batch)
-            ## backward pass
-            loss.backward()
-            ## update weights
-            optimizer.step()
-            ender.record()
-            # update the batch counter
-            batch_cnt += 1
-            # append the start and end events to the lists
-            starter_list.append(starter)
-            ender_list.append(ender)
-            # do training metrics and QPS statistics
-            num_queries_per_batch = len(labels)
-            num_queries_per_batch_list.append(num_queries_per_batch)
-            loss_per_batch_list.append(loss.cpu().item())
-            # do zch statistics
-            benchmark_probe.record_mcec_state(stage="after_fwd")
-            # update zch statistics
-            benchmark_probe.update()
-            # push the zch stats to the queue
+    ## code for profiling
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        device = "cuda"
+        activities += [ProfilerActivity.CUDA]
+    ## end code for profiling
+    with profile(activities=activities, record_shapes=True) as prof:
+        for epoch_idx in range(args.epochs):
+            model.train()
+            starter_list = []
+            ender_list = []
+            num_queries_per_batch_list = []
+            pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
+            for batch_idx, batch in enumerate(pbar):
+                # batch = batch.to(device)
+                batch = batch.to(device)
+                starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
+                    enable_timing=True
+                )
+                if len(args.zch_method) > 0:
+                    benchmark_probe.record_mcec_state(stage="before_fwd")
+                # forward pass
+                with record_function(f"model training on batch {batch_idx}"):
+                    starter.record()
+                    # zero the gradients
+                    optimizer.zero_grad()
+                    loss, (loss_values, pred_logits, labels) = model(batch)
+                    loss.backward()
+                    optimizer.step()
+                    ender.record()
+                # do statistics
+                num_queries_per_batch = len(labels)
+                starter_list.append(starter)
+                ender_list.append(ender)
+                num_queries_per_batch_list.append(num_queries_per_batch)
+                if len(args.zch_method) > 0:
+                    benchmark_probe.record_mcec_state(stage="after_fwd")
+                    # update zch statistics
+                    benchmark_probe.update()
+                    # push the zch stats to the queue
+                    msg_content = {
+                        "epoch_idx": epoch_idx,
+                        "batch_idx": batch_idx,
+                        "rank": rank,
+                        "mch_stats": benchmark_probe.get_mch_stats(),
+                    }
+                    queue.put(
+                        ("mch_stats", msg_content),
+                    )
+                if (
+                    batch_idx % interval_num_batches_show_qps == 0
+                    or batch_idx == len(train_dataloader) - 1
+                ):
+                    if batch_idx == 0:
+                        # skip the first batch since it is not a full batch
+                        continue
+                    # synchronize all the threads to get the exact number of batches
+                    torch.cuda.synchronize()
+                    # calculate the qps
+                    # NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
+                    # because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
+                    # and this is a heavy operation (takes several milliseconds). So we only do this calculation every
+                    # interval_num_batches_show_qps batches to reduce the overhead.
+                    ## get per batch time list by calculating the time difference between the start and end events of each batch
+                    per_batch_time_list = []
+                    for i in range(len(starter_list)):
+                        per_batch_time_list.append(
+                            starter_list[i].elapsed_time(ender_list[i]) / 1000
+                        )  # convert to seconds by dividing by 1000
+                    ## calculate the total time in the interval
+                    total_time_in_interval = sum(per_batch_time_list)
+                    ## calculate the total number of queries in the interval
+                    total_num_queries_in_interval = sum(num_queries_per_batch_list)
+                    ## fabricate the message and total_num_queries_in_interval to the queue
+                    interval_start_batch_idx = (
+                        batch_idx - interval_num_batches_show_qps
+                        if batch_idx >= interval_num_batches_show_qps
+                        else 0
+                    )  # the start batch index of the interval
+                    interval_end_batch_idx = (
+                        batch_idx  # the end batch index of the interval
+                    )
+                    ## fabricate the message content
+                    msg_content = {
+                        "epoch_idx": epoch_idx,
+                        "rank": rank,
+                        "interval_start_batch_idx": interval_start_batch_idx,
+                        "interval_end_batch_idx": interval_end_batch_idx,
+                        "per_batch_time_list": per_batch_time_list,
+                        "per_batch_num_queries_list": num_queries_per_batch_list,
+                    }
+                    ## put the message into the queue
+                    queue.put(("duration_and_num_queries", msg_content))
+                    qps_per_interval = (
+                        total_num_queries_in_interval / total_time_in_interval
+                    )
+                    total_time_in_training += total_time_in_interval
+                    total_num_queries_in_training += total_num_queries_in_interval
+                    pbar.set_postfix(
+                        {
+                            "QPS": qps_per_interval,
+                        }
+                    )
+                    pbar.update(interval_num_batches_show_qps)
+                    # reset the lists
+                    starter_list = []
+                    ender_list = []
+                    num_queries_per_batch_list = []
+                if batch_idx > 50:
+                    # stop early after 50 batches to keep the profiling run short
+                    break
+            # after each epoch, do validation
+            eval_result_dict = evaluation(model, test_dataloader, device)
+            # print the evaluation result
+            print(f"Evaluation result: {eval_result_dict}")
+            # send the evaluation result to the queue
             msg_content = {
-                "epoch_idx": epoch_idx,
-                "batch_idx": batch_idx,
-                "batch_cnt": batch_cnt,
+                "epoch_idx": args.epochs,
                 "rank": rank,
-                "mch_stats": benchmark_probe.get_mch_stats(),
+                "eval_result_dict": eval_result_dict,
             }
-            queue.put(
-                ("mch_stats", msg_content),
-            )
-            if (
-                batch_idx % interval_num_batches_show_qps == 0
-                or batch_idx == len(train_dataloader) - 1
-            ):
-                if batch_idx == 0:
-                    # skip the first batch since it is not a full batch
-                    continue
-                logger.info(f"[rank {rank}] batch_idx: {batch_idx} get the stats")
-                # synchronize all the threads to get the exact number of batches
-                torch.cuda.synchronize()
-                # calculate the qps
-                # NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
-                # because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
-                # and this is a heavy operation (takes several milliseconds). So we only do this calculation every
-                # interval_num_batches_show_qps batches to reduce the overhead.
-                ## get per batch time list by calculating the time difference between the start and end events of each batch
-                per_batch_time_list = []
-                for i in range(len(starter_list)):
-                    per_batch_time_list.append(
-                        starter_list[i].elapsed_time(ender_list[i]) / 1000
-                    )  # convert to seconds by dividing by 1000
-                ## calculate the total time in the interval
-                total_time_in_interval = sum(per_batch_time_list)
-                ## calculate the total number of queries in the interval
-                total_num_queries_in_interval = sum(num_queries_per_batch_list)
-                ## fabricate the message and total_num_queries_in_interval to the queue
-                interval_start_batch_idx = (
-                    batch_idx - interval_num_batches_show_qps
-                    if batch_idx >= interval_num_batches_show_qps
-                    else 0
-                )  # the start batch index of the interval
-                interval_start_batch_cnt = (
-                    batch_cnt - interval_num_batches_show_qps
-                    if batch_cnt >= interval_num_batches_show_qps
-                    else 0
-                )  # the start batch counter of the interval
-                interval_end_batch_idx = (
-                    batch_idx  # the end batch index of the interval
-                )
-                ## fabricate the message content
-                msg_content = {
-                    "epoch_idx": epoch_idx,
-                    "rank": rank,
-                    "interval_start_batch_idx": interval_start_batch_idx,
-                    "interval_end_batch_idx": interval_end_batch_idx,
-                    "interval_start_batch_cnt": interval_start_batch_cnt,
-                    "interval_end_batch_cnt": batch_cnt,
-                    "per_batch_time_list": per_batch_time_list,
-                    "per_batch_num_queries_list": num_queries_per_batch_list,
-                }
-                ## put the message into the queue
-                queue.put(("duration_and_num_queries", msg_content))
-                ## also fabricate the message for loss
-                msg_content = {
-                    "epoch_idx": epoch_idx,
-                    "rank": rank,
-                    "interval_start_batch_idx": interval_start_batch_idx,
-                    "interval_end_batch_idx": interval_end_batch_idx,
-                    "interval_start_batch_cnt": interval_start_batch_cnt,
-                    "interval_end_batch_cnt": batch_cnt,
-                    "per_batch_loss_list": loss_per_batch_list,
-                }
-                ## put the message into the queue
-                queue.put(("training_metrics", msg_content))
-                # calculate QPS per statistic interval
-                qps_per_interval = (
-                    total_num_queries_in_interval / total_time_in_interval
-                )
-                total_time_in_training += total_time_in_interval
-                total_num_queries_in_training += total_num_queries_in_interval
-                pbar.set_postfix(
-                    {
-                        "QPS": qps_per_interval,
-                    }
-                )
-                pbar.update(interval_num_batches_show_qps)
-                # reset the lists
-                starter_list = []
-                ender_list = []
-                num_queries_per_batch_list = []
-                loss_per_batch_list = []
-        # after training of each epoch, do validation
-        logger.info(f"[rank {rank}] do validation after training of epoch {epoch_idx}")
-        metric_values = evaluation(
-            metric_modules,
-            model,
-            test_dataloader,
-            device,
-            nonzch_remapper if len(args.zch_method) == 0 else None,
-        )
-        # print the evaluation result
-        print(f"Evaluation result: {metric_values}")
-        # send the evaluation result to the queue
-        msg_content = {
-            "epoch_idx": epoch_idx,
-            "rank": rank,
-            "eval_result_dict": metric_values,
-        }
-        queue.put(("eval_result", msg_content))
+            queue.put(("eval_result", msg_content))
 
-    logger.info(
-        f"[rank {rank}] finished, sleep for 15 seconds before sending finish signal and exit"
+    prof.export_chrome_trace(
+        f"/home/lizhouyu/tmp/trace_noremap_{args.zch_method if len(args.zch_method) > 0 else 'nonzch'}_{args.model_name}_fullloop_tbsize_{args.num_embeddings}_rank{rank}.json"
     )
-    time.sleep(15)
+    time.sleep(10)
     queue.put(("finished", {"rank": rank}))
     print("finished")
     return
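The per-batch QPS bookkeeping in the new loop uses CUDA events rather than wall-clock time because GPU work is launched asynchronously; elapsed_time() is only valid once both events have completed, which is why torch.cuda.synchronize() is called before the interval times are read. A standalone sketch of that timing pattern, with a hypothetical matmul as the workload rather than this benchmark's model:

import torch

assert torch.cuda.is_available()
x = torch.randn(4096, 4096, device="cuda")

starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

starter.record()          # enqueue the start timestamp on the current stream
y = x @ x                 # GPU work to be timed
ender.record()            # enqueue the end timestamp

torch.cuda.synchronize()  # make sure both events have actually completed
elapsed_s = starter.elapsed_time(ender) / 1000  # elapsed_time returns milliseconds
print(f"matmul took {elapsed_s:.4f} s")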
@@ -809,8 +772,6 @@ def statistic(args: argparse.Namespace, queue: multiprocessing.Queue) -> None:
 if __name__ == "__main__":
     args: argparse.Namespace = parse_args(sys.argv[1:])
 
-    __builtins__.__dict__["profile"] = LineProfiler()
-
     # set environment variables
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())
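For context on the deleted line above: assigning a LineProfiler instance into builtins is the usual way to let @profile-decorated functions resolve without running the script under kernprof, and removing it pairs with dropping the line_profiler import at the top of the file. Roughly, the removed mechanism works like the following sketch of typical line_profiler usage (an assumption about the general pattern, not code from this repo):

from line_profiler import LineProfiler

lp = LineProfiler()
__builtins__.__dict__["profile"] = lp  # makes the @profile name resolvable at definition time

@profile  # wraps the function so per-line timings are recorded on each call
def hot_loop():
    total = 0
    for i in range(1_000_000):
        total += i
    return total

hot_loop()
lp.print_stats()  # print the collected line-by-line timings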