
Commit 15e6d34

lizhouyu authored and facebook-github-bot committed
Check QPS Regressions of MPZCH

Summary: Check the causes of the QPS regression of MPZCH modules.

Differential Revision: D77189125
1 parent 45d5c4d commit 15e6d34
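
For reference, the main change below swaps the benchmark's line_profiler hook for PyTorch's built-in profiler and exports a Chrome trace per rank. A minimal sketch of that pattern, independent of the benchmark (the model, optimizer, batches, and trace path are placeholders, not the objects used in benchmark_zch.py):

import torch
from torch.profiler import profile, record_function, ProfilerActivity

def run_profiled_steps(model, optimizer, batches, trace_path="trace_rank0.json"):
    # Always collect CPU activity; add CUDA kernel timing when a GPU is present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        for batch_idx, (inputs, labels) in enumerate(batches):
            # Label each step so it is easy to locate in the trace viewer.
            with record_function(f"train_step_{batch_idx}"):
                optimizer.zero_grad()
                loss = torch.nn.functional.mse_loss(model(inputs), labels)
                loss.backward()
                optimizer.step()
    # Inspect the exported file in chrome://tracing or Perfetto.
    prof.export_chrome_trace(trace_path)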

File tree

5 files changed: +406 -165 lines


torchrec/distributed/benchmark/benchmark_zch/benchmark_zch.py
124 additions, 163 deletions

@@ -9,22 +9,20 @@
 import argparse
 import csv
 import json
-import logging
 import multiprocessing
 import os
 import sys
 import time
 
-from typing import cast, Dict, Iterator, List, Optional
+from typing import Dict, List, Optional
 
 import numpy as np
 
 import torch
 import torch.nn as nn
 
-from line_profiler import LineProfiler
-
 from torch import distributed as dist
+from torch.profiler import profile, ProfilerActivity, record_function
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter # @manual //caffe2:torch_tensorboard
 from torchrec.metrics.metrics_namespace import MetricPrefix
@@ -38,7 +36,6 @@
 from .benchmark_zch_utils import BenchmarkMCProbe, get_logger, get_module_from_instance
 
 from .data.get_dataloader import get_dataloader
-from .data.get_metric_modules import get_metric_modules
 from .data.nonzch_remapper import NonZchModRemapperModule
 from .models.apply_optimizers import (
     apply_dense_optimizers,
@@ -86,7 +83,6 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
 
     # get metric modules
     logger.info(f"[rank {rank}] get metric modules")
-    metric_modules = get_metric_modules(rank, args, device)
 
     # make the model
     logger.info(f"[rank {rank}] make model")
@@ -146,166 +142,133 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
     total_num_queries_in_training = 0
 
     # train the model
-    logger.info(f"[rank {rank}] train the model")
-    batch_cnt = 0
-    for epoch_idx in range(args.epochs):
-        model.train()
-        starter_list = []
-        ender_list = []
-        num_queries_per_batch_list = []
-        loss_per_batch_list = []
-        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
-        for batch_idx, batch in enumerate(pbar):
-            # batch = batch.to(device)
-            batch = batch.to(device)
-            # remap the batch if needed
-            if len(args.zch_method) == 0:
-                # pyre-ignore [16] # NOTE: pyre reports nonzch_remapper can be None, but when reach to this branch of condition, we know it is not None
-                batch = nonzch_remapper.remap(batch)
-            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
-                enable_timing=True
-            )
-            if True or len(args.zch_method) > 0:
-                benchmark_probe.record_mcec_state(stage="before_fwd")
-            # train model
-            starter.record()
-            ## zero the gradients
-            optimizer.zero_grad()
-            ## forward pass
-            loss, (loss_values, pred_logits, labels, weights) = model(batch)
-            ## backward pass
-            loss.backward()
-            ## update weights
-            optimizer.step()
-            ender.record()
-            # update the batch counter
-            batch_cnt += 1
-            # append the start and end events to the lists
-            starter_list.append(starter)
-            ender_list.append(ender)
-            # do training metrics and QPS statistics
-            num_queries_per_batch = len(labels)
-            num_queries_per_batch_list.append(num_queries_per_batch)
-            loss_per_batch_list.append(loss.cpu().item())
-            # do zch statistics
-            benchmark_probe.record_mcec_state(stage="after_fwd")
-            # update zch statistics
-            benchmark_probe.update()
-            # push the zch stats to the queue
+    ## code for profiling
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        device = "cuda"
+        activities += [ProfilerActivity.CUDA]
+    ## end code for profiling
+    with profile(activities=activities, record_shapes=True) as prof:
+        for epoch_idx in range(args.epochs):
+            model.train()
+            starter_list = []
+            ender_list = []
+            num_queries_per_batch_list = []
+            pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
+            for batch_idx, batch in enumerate(pbar):
+                # batch = batch.to(device)
+                batch = batch.to(device)
+                starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
+                    enable_timing=True
+                )
+                if len(args.zch_method) > 0:
+                    benchmark_probe.record_mcec_state(stage="before_fwd")
+                # forward pass
+                with record_function(f"model training on batch {batch_idx}"):
+                    starter.record()
+                    # zero the gradients
+                    optimizer.zero_grad()
+                    loss, (loss_values, pred_logits, labels) = model(batch)
+                    loss.backward()
+                    optimizer.step()
+                    ender.record()
+                # do statistics
+                num_queries_per_batch = len(labels)
+                starter_list.append(starter)
+                ender_list.append(ender)
+                num_queries_per_batch_list.append(num_queries_per_batch)
+                if len(args.zch_method) > 0:
+                    benchmark_probe.record_mcec_state(stage="after_fwd")
+                    # update zch statistics
+                    benchmark_probe.update()
+                    # push the zch stats to the queue
+                    msg_content = {
+                        "epoch_idx": epoch_idx,
+                        "batch_idx": batch_idx,
+                        "rank": rank,
+                        "mch_stats": benchmark_probe.get_mch_stats(),
+                    }
+                    queue.put(
+                        ("mch_stats", msg_content),
+                    )
+                if (
+                    batch_idx % interval_num_batches_show_qps == 0
+                    or batch_idx == len(train_dataloader) - 1
+                ):
+                    if batch_idx == 0:
+                        # skip the first batch since it is not a full batch
+                        continue
+                    # synchronize all the threads to get the exact number of batches
+                    torch.cuda.synchronize()
+                    # calculate the qps
+                    # NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
+                    # because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
+                    # and this is a heavy operation (takes several milliseconds). So we only do this calculation every
+                    # interval_num_batches_show_qps batches to reduce the overhead.
+                    ## get per batch time list by calculating the time difference between the start and end events of each batch
+                    per_batch_time_list = []
+                    for i in range(len(starter_list)):
+                        per_batch_time_list.append(
+                            starter_list[i].elapsed_time(ender_list[i]) / 1000
+                        ) # convert to seconds by dividing by 1000
+                    ## calculate the total time in the interval
+                    total_time_in_interval = sum(per_batch_time_list)
+                    ## calculate the total number of queries in the interval
+                    total_num_queries_in_interval = sum(num_queries_per_batch_list)
+                    ## fabricate the message and total_num_queries_in_interval to the queue
+                    interval_start_batch_idx = (
+                        batch_idx - interval_num_batches_show_qps
+                        if batch_idx >= interval_num_batches_show_qps
+                        else 0
+                    ) # the start batch index of the interval
+                    interval_end_batch_idx = (
+                        batch_idx # the end batch index of the interval
+                    )
+                    ## fabricate the message content
+                    msg_content = {
+                        "epoch_idx": epoch_idx,
+                        "rank": rank,
+                        "interval_start_batch_idx": interval_start_batch_idx,
+                        "interval_end_batch_idx": interval_end_batch_idx,
+                        "per_batch_time_list": per_batch_time_list,
+                        "per_batch_num_queries_list": num_queries_per_batch_list,
+                    }
+                    ## put the message into the queue
+                    queue.put(("duration_and_num_queries", msg_content))
+                    qps_per_interval = (
+                        total_num_queries_in_interval / total_time_in_interval
+                    )
+                    total_time_in_training += total_time_in_interval
+                    total_num_queries_in_training += total_num_queries_in_interval
+                    pbar.set_postfix(
+                        {
+                            "QPS": qps_per_interval,
+                        }
+                    )
+                    pbar.update(interval_num_batches_show_qps)
+                    # reset the lists
+                    starter_list = []
+                    ender_list = []
+                    num_queries_per_batch_list = []
+                if batch_idx > 50:
+                    # skip the first batch since it is not a full batch
+                    break
+            # after each epoch, do validation
+            eval_result_dict = evaluation(model, test_dataloader, device)
+            # print the evaluation result
+            print(f"Evaluation result: {eval_result_dict}")
+            # send the evaluation result to the queue
             msg_content = {
-                "epoch_idx": epoch_idx,
-                "batch_idx": batch_idx,
-                "batch_cnt": batch_cnt,
+                "epoch_idx": args.epochs,
                 "rank": rank,
-                "mch_stats": benchmark_probe.get_mch_stats(),
+                "eval_result_dict": eval_result_dict,
             }
-            queue.put(
-                ("mch_stats", msg_content),
-            )
-            if (
-                batch_idx % interval_num_batches_show_qps == 0
-                or batch_idx == len(train_dataloader) - 1
-            ):
-                if batch_idx == 0:
-                    # skip the first batch since it is not a full batch
-                    continue
-                logger.info(f"[rank {rank}] batch_idx: {batch_idx} get the stats")
-                # synchronize all the threads to get the exact number of batches
-                torch.cuda.synchronize()
-                # calculate the qps
-                # NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
-                # because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
-                # and this is a heavy operation (takes several milliseconds). So we only do this calculation every
-                # interval_num_batches_show_qps batches to reduce the overhead.
-                ## get per batch time list by calculating the time difference between the start and end events of each batch
-                per_batch_time_list = []
-                for i in range(len(starter_list)):
-                    per_batch_time_list.append(
-                        starter_list[i].elapsed_time(ender_list[i]) / 1000
-                    ) # convert to seconds by dividing by 1000
-                ## calculate the total time in the interval
-                total_time_in_interval = sum(per_batch_time_list)
-                ## calculate the total number of queries in the interval
-                total_num_queries_in_interval = sum(num_queries_per_batch_list)
-                ## fabricate the message and total_num_queries_in_interval to the queue
-                interval_start_batch_idx = (
-                    batch_idx - interval_num_batches_show_qps
-                    if batch_idx >= interval_num_batches_show_qps
-                    else 0
-                ) # the start batch index of the interval
-                interval_start_batch_cnt = (
-                    batch_cnt - interval_num_batches_show_qps
-                    if batch_cnt >= interval_num_batches_show_qps
-                    else 0
-                ) # the start batch counter of the interval
-                interval_end_batch_idx = (
-                    batch_idx # the end batch index of the interval
-                )
-                ## fabricate the message content
-                msg_content = {
-                    "epoch_idx": epoch_idx,
-                    "rank": rank,
-                    "interval_start_batch_idx": interval_start_batch_idx,
-                    "interval_end_batch_idx": interval_end_batch_idx,
-                    "interval_start_batch_cnt": interval_start_batch_cnt,
-                    "interval_end_batch_cnt": batch_cnt,
-                    "per_batch_time_list": per_batch_time_list,
-                    "per_batch_num_queries_list": num_queries_per_batch_list,
-                }
-                ## put the message into the queue
-                queue.put(("duration_and_num_queries", msg_content))
-                ## also fabricate the message for loss
-                msg_content = {
-                    "epoch_idx": epoch_idx,
-                    "rank": rank,
-                    "interval_start_batch_idx": interval_start_batch_idx,
-                    "interval_end_batch_idx": interval_end_batch_idx,
-                    "interval_start_batch_cnt": interval_start_batch_cnt,
-                    "interval_end_batch_cnt": batch_cnt,
-                    "per_batch_loss_list": loss_per_batch_list,
-                }
-                ## put the message into the queue
-                queue.put(("training_metrics", msg_content))
-                # calculate QPS per statistic interval
-                qps_per_interval = (
-                    total_num_queries_in_interval / total_time_in_interval
-                )
-                total_time_in_training += total_time_in_interval
-                total_num_queries_in_training += total_num_queries_in_interval
-                pbar.set_postfix(
-                    {
-                        "QPS": qps_per_interval,
-                    }
-                )
-                pbar.update(interval_num_batches_show_qps)
-                # reset the lists
-                starter_list = []
-                ender_list = []
-                num_queries_per_batch_list = []
-                loss_per_batch_list = []
-        # after training of each epoch, do validation
-        logger.info(f"[rank {rank}] do validation after training of epoch {epoch_idx}")
-        metric_values = evaluation(
-            metric_modules,
-            model,
-            test_dataloader,
-            device,
-            nonzch_remapper if len(args.zch_method) == 0 else None,
-        )
-        # print the evaluation result
-        print(f"Evaluation result: {metric_values}")
-        # send the evaluation result to the queue
-        msg_content = {
-            "epoch_idx": epoch_idx,
-            "rank": rank,
-            "eval_result_dict": metric_values,
-        }
-        queue.put(("eval_result", msg_content))
+            queue.put(("eval_result", msg_content))
 
-    logger.info(
-        f"[rank {rank}] finished, sleep for 15 seconds before sending finish signal and exit"
+    prof.export_chrome_trace(
+        f"/home/lizhouyu/tmp/trace_noremap_{args.zch_method if len(args.zch_method) > 0 else 'nonzch'}_{args.model_name}_fullloop_tbsize_{args.num_embeddings}_rank{rank}.json"
     )
-    time.sleep(15)
+    time.sleep(10)
     queue.put(("finished", {"rank": rank}))
     print("finished")
     return
@@ -809,8 +772,6 @@ def statistic(args: argparse.Namespace, queue: multiprocessing.Queue) -> None:
 if __name__ == "__main__":
     args: argparse.Namespace = parse_args(sys.argv[1:])
 
-    __builtins__.__dict__["profile"] = LineProfiler()
-
     # set environment variables
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())
Binary file not shown.

torchrec/distributed/benchmark/benchmark_zch/data/configs/criteo_kaggle.yaml
1 addition, 1 deletion

@@ -1,4 +1,4 @@
-dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
+dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed_small"
 batch_size: 4096
 seed: 0
 multitask_configs:

torchrec/distributed/benchmark/benchmark_zch/data/configs/movielens_1m.yaml
1 addition, 1 deletion

@@ -1,6 +1,6 @@
 dataset_path: "/home/lizhouyu/oss_github/generative-recommenders/tmp/data/ml-1m"
 batch_size: 16
-train_split_percentage: 0.75
+train_split_percentage: 0.8
 num_workers: 4
 prefetch_factor: 4
 max_num_candidates: 10
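
Both dataset configs are plain YAML files whose keys end up on the benchmark's argument namespace. A minimal sketch of how such a file could be read and merged onto argparse arguments (the helper name and merge policy are illustrative, not the benchmark's actual loader):

import argparse
import yaml  # PyYAML

def load_dataset_config(path: str, args: argparse.Namespace) -> argparse.Namespace:
    # Copy YAML keys onto the namespace, keeping any value already set on the CLI.
    with open(path) as f:
        config = yaml.safe_load(f)
    for key, value in config.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args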
