Skip to content

Commit 5c7574c

Browse files
lizhouyu and facebook-github-bot
authored and committed
Check QPS Regresses of MPZCH (#3205)
Summary: Pull Request resolved: #3205. Check the causes of the QPS regression of MPZCH modules. Differential Revision: D77189125
1 parent 45d5c4d commit 5c7574c

File tree

5 files changed

+410
-165
lines changed

5 files changed

+410
-165
lines changed

torchrec/distributed/benchmark/benchmark_zch/benchmark_zch.py

Lines changed: 128 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,22 @@
99
import argparse
1010
import csv
1111
import json
12-
import logging
1312
import multiprocessing
1413
import os
1514
import sys
1615
import time
1716

18-
from typing import cast, Dict, Iterator, List, Optional
17+
from typing import Dict, List, Optional
1918

2019
import numpy as np
2120

2221
import torch
2322
import torch.nn as nn
2423

25-
from line_profiler import LineProfiler
26-
2724
from torch import distributed as dist
25+
26+
# pyre-ignore [21] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
27+
from torch.profiler import profile, ProfilerActivity, record_function
2828
from torch.utils.data import DataLoader
2929
from torch.utils.tensorboard import SummaryWriter # @manual //caffe2:torch_tensorboard
3030
from torchrec.metrics.metrics_namespace import MetricPrefix
@@ -38,7 +38,6 @@
3838
from .benchmark_zch_utils import BenchmarkMCProbe, get_logger, get_module_from_instance
3939

4040
from .data.get_dataloader import get_dataloader
41-
from .data.get_metric_modules import get_metric_modules
4241
from .data.nonzch_remapper import NonZchModRemapperModule
4342
from .models.apply_optimizers import (
4443
apply_dense_optimizers,
@@ -86,7 +85,6 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
8685

8786
# get metric modules
8887
logger.info(f"[rank {rank}] get metric modules")
89-
metric_modules = get_metric_modules(rank, args, device)
9088

9189
# make the model
9290
logger.info(f"[rank {rank}] make model")
@@ -146,166 +144,135 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
146144
total_num_queries_in_training = 0
147145

148146
# train the model
149-
logger.info(f"[rank {rank}] train the model")
150-
batch_cnt = 0
151-
for epoch_idx in range(args.epochs):
152-
model.train()
153-
starter_list = []
154-
ender_list = []
155-
num_queries_per_batch_list = []
156-
loss_per_batch_list = []
157-
pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
158-
for batch_idx, batch in enumerate(pbar):
159-
# batch = batch.to(device)
160-
batch = batch.to(device)
161-
# remap the batch if needed
162-
if len(args.zch_method) == 0:
163-
# pyre-ignore [16] # NOTE: pyre reports nonzch_remapper can be None, but when reach to this branch of condition, we know it is not None
164-
batch = nonzch_remapper.remap(batch)
165-
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
166-
enable_timing=True
167-
)
168-
if True or len(args.zch_method) > 0:
169-
benchmark_probe.record_mcec_state(stage="before_fwd")
170-
# train model
171-
starter.record()
172-
## zero the gradients
173-
optimizer.zero_grad()
174-
## forward pass
175-
loss, (loss_values, pred_logits, labels, weights) = model(batch)
176-
## backward pass
177-
loss.backward()
178-
## update weights
179-
optimizer.step()
180-
ender.record()
181-
# update the batch counter
182-
batch_cnt += 1
183-
# append the start and end events to the lists
184-
starter_list.append(starter)
185-
ender_list.append(ender)
186-
# do training metrics and QPS statistics
187-
num_queries_per_batch = len(labels)
188-
num_queries_per_batch_list.append(num_queries_per_batch)
189-
loss_per_batch_list.append(loss.cpu().item())
190-
# do zch statistics
191-
benchmark_probe.record_mcec_state(stage="after_fwd")
192-
# update zch statistics
193-
benchmark_probe.update()
194-
# push the zch stats to the queue
147+
## code for profiling
148+
# pyre-ignore [16] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
149+
activities = [ProfilerActivity.CPU]
150+
if torch.cuda.is_available():
151+
device = torch.device("cuda")
152+
# pyre-ignore [16] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
153+
activities += [ProfilerActivity.CUDA]
154+
## end code for profiling
155+
with profile(activities=activities, record_shapes=True) as prof:
156+
for epoch_idx in range(args.epochs):
157+
model.train()
158+
starter_list = []
159+
ender_list = []
160+
num_queries_per_batch_list = []
161+
pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
162+
for batch_idx, batch in enumerate(pbar):
163+
# batch = batch.to(device)
164+
batch = batch.to(device)
165+
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
166+
enable_timing=True
167+
)
168+
if len(args.zch_method) > 0:
169+
benchmark_probe.record_mcec_state(stage="before_fwd")
170+
# full training step (zero grad, forward, backward, optimizer step), bracketed by the CUDA timing events
171+
with record_function(f"model training on batch {batch_idx}"):
172+
starter.record()
173+
# zero the gradients
174+
optimizer.zero_grad()
175+
loss, (loss_values, pred_logits, labels) = model(batch)
176+
loss.backward()
177+
optimizer.step()
178+
ender.record()
179+
# do statistics
180+
num_queries_per_batch = len(labels)
181+
starter_list.append(starter)
182+
ender_list.append(ender)
183+
num_queries_per_batch_list.append(num_queries_per_batch)
184+
if len(args.zch_method) > 0:
185+
benchmark_probe.record_mcec_state(stage="after_fwd")
186+
# update zch statistics
187+
benchmark_probe.update()
188+
# push the zch stats to the queue
189+
msg_content = {
190+
"epoch_idx": epoch_idx,
191+
"batch_idx": batch_idx,
192+
"rank": rank,
193+
"mch_stats": benchmark_probe.get_mch_stats(),
194+
}
195+
queue.put(
196+
("mch_stats", msg_content),
197+
)
198+
if (
199+
batch_idx % interval_num_batches_show_qps == 0
200+
or batch_idx == len(train_dataloader) - 1
201+
):
202+
if batch_idx == 0:
203+
# skip the first batch since it is not a full batch
204+
continue
205+
# wait for the queued CUDA work to finish so the recorded events can be read with elapsed_time()
206+
torch.cuda.synchronize()
207+
# calculate the qps
208+
# NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
209+
# because reading the event timings requires calling torch.cuda.synchronize(), which blocks the host until the device's queued work is done,
210+
# and this is a heavy operation (takes several milliseconds). So we only do this calculation every
211+
# interval_num_batches_show_qps batches to reduce the overhead.
212+
## get per batch time list by calculating the time difference between the start and end events of each batch
213+
per_batch_time_list = []
214+
for i in range(len(starter_list)):
215+
per_batch_time_list.append(
216+
starter_list[i].elapsed_time(ender_list[i]) / 1000
217+
) # convert to seconds by dividing by 1000
218+
## calculate the total time in the interval
219+
total_time_in_interval = sum(per_batch_time_list)
220+
## calculate the total number of queries in the interval
221+
total_num_queries_in_interval = sum(num_queries_per_batch_list)
222+
## fabricate the message and total_num_queries_in_interval to the queue
223+
interval_start_batch_idx = (
224+
batch_idx - interval_num_batches_show_qps
225+
if batch_idx >= interval_num_batches_show_qps
226+
else 0
227+
) # the start batch index of the interval
228+
interval_end_batch_idx = (
229+
batch_idx # the end batch index of the interval
230+
)
231+
## fabricate the message content
232+
msg_content = {
233+
"epoch_idx": epoch_idx,
234+
"rank": rank,
235+
"interval_start_batch_idx": interval_start_batch_idx,
236+
"interval_end_batch_idx": interval_end_batch_idx,
237+
"per_batch_time_list": per_batch_time_list,
238+
"per_batch_num_queries_list": num_queries_per_batch_list,
239+
}
240+
## put the message into the queue
241+
queue.put(("duration_and_num_queries", msg_content))
242+
qps_per_interval = (
243+
total_num_queries_in_interval / total_time_in_interval
244+
)
245+
total_time_in_training += total_time_in_interval
246+
total_num_queries_in_training += total_num_queries_in_interval
247+
pbar.set_postfix(
248+
{
249+
"QPS": qps_per_interval,
250+
}
251+
)
252+
pbar.update(interval_num_batches_show_qps)
253+
# reset the lists
254+
starter_list = []
255+
ender_list = []
256+
num_queries_per_batch_list = []
257+
if batch_idx > 50:
258+
# stop the epoch once batch_idx exceeds 50 (presumably to keep the profiled run short)
259+
break
260+
# after each epoch, do validation
261+
eval_result_dict = evaluation(model, test_dataloader, device)
262+
# print the evaluation result
263+
print(f"Evaluation result: {eval_result_dict}")
264+
# send the evaluation result to the queue
195265
msg_content = {
196-
"epoch_idx": epoch_idx,
197-
"batch_idx": batch_idx,
198-
"batch_cnt": batch_cnt,
266+
"epoch_idx": args.epochs,
199267
"rank": rank,
200-
"mch_stats": benchmark_probe.get_mch_stats(),
268+
"eval_result_dict": eval_result_dict,
201269
}
202-
queue.put(
203-
("mch_stats", msg_content),
204-
)
205-
if (
206-
batch_idx % interval_num_batches_show_qps == 0
207-
or batch_idx == len(train_dataloader) - 1
208-
):
209-
if batch_idx == 0:
210-
# skip the first batch since it is not a full batch
211-
continue
212-
logger.info(f"[rank {rank}] batch_idx: {batch_idx} get the stats")
213-
# synchronize all the threads to get the exact number of batches
214-
torch.cuda.synchronize()
215-
# calculate the qps
216-
# NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
217-
# because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
218-
# and this is a heavy operation (takes several milliseconds). So we only do this calculation every
219-
# interval_num_batches_show_qps batches to reduce the overhead.
220-
## get per batch time list by calculating the time difference between the start and end events of each batch
221-
per_batch_time_list = []
222-
for i in range(len(starter_list)):
223-
per_batch_time_list.append(
224-
starter_list[i].elapsed_time(ender_list[i]) / 1000
225-
) # convert to seconds by dividing by 1000
226-
## calculate the total time in the interval
227-
total_time_in_interval = sum(per_batch_time_list)
228-
## calculate the total number of queries in the interval
229-
total_num_queries_in_interval = sum(num_queries_per_batch_list)
230-
## fabricate the message and total_num_queries_in_interval to the queue
231-
interval_start_batch_idx = (
232-
batch_idx - interval_num_batches_show_qps
233-
if batch_idx >= interval_num_batches_show_qps
234-
else 0
235-
) # the start batch index of the interval
236-
interval_start_batch_cnt = (
237-
batch_cnt - interval_num_batches_show_qps
238-
if batch_cnt >= interval_num_batches_show_qps
239-
else 0
240-
) # the start batch counter of the interval
241-
interval_end_batch_idx = (
242-
batch_idx # the end batch index of the interval
243-
)
244-
## fabricate the message content
245-
msg_content = {
246-
"epoch_idx": epoch_idx,
247-
"rank": rank,
248-
"interval_start_batch_idx": interval_start_batch_idx,
249-
"interval_end_batch_idx": interval_end_batch_idx,
250-
"interval_start_batch_cnt": interval_start_batch_cnt,
251-
"interval_end_batch_cnt": batch_cnt,
252-
"per_batch_time_list": per_batch_time_list,
253-
"per_batch_num_queries_list": num_queries_per_batch_list,
254-
}
255-
## put the message into the queue
256-
queue.put(("duration_and_num_queries", msg_content))
257-
## also fabricate the message for loss
258-
msg_content = {
259-
"epoch_idx": epoch_idx,
260-
"rank": rank,
261-
"interval_start_batch_idx": interval_start_batch_idx,
262-
"interval_end_batch_idx": interval_end_batch_idx,
263-
"interval_start_batch_cnt": interval_start_batch_cnt,
264-
"interval_end_batch_cnt": batch_cnt,
265-
"per_batch_loss_list": loss_per_batch_list,
266-
}
267-
## put the message into the queue
268-
queue.put(("training_metrics", msg_content))
269-
# calculate QPS per statistic interval
270-
qps_per_interval = (
271-
total_num_queries_in_interval / total_time_in_interval
272-
)
273-
total_time_in_training += total_time_in_interval
274-
total_num_queries_in_training += total_num_queries_in_interval
275-
pbar.set_postfix(
276-
{
277-
"QPS": qps_per_interval,
278-
}
279-
)
280-
pbar.update(interval_num_batches_show_qps)
281-
# reset the lists
282-
starter_list = []
283-
ender_list = []
284-
num_queries_per_batch_list = []
285-
loss_per_batch_list = []
286-
# after training of each epoch, do validation
287-
logger.info(f"[rank {rank}] do validation after training of epoch {epoch_idx}")
288-
metric_values = evaluation(
289-
metric_modules,
290-
model,
291-
test_dataloader,
292-
device,
293-
nonzch_remapper if len(args.zch_method) == 0 else None,
294-
)
295-
# print the evaluation result
296-
print(f"Evaluation result: {metric_values}")
297-
# send the evaluation result to the queue
298-
msg_content = {
299-
"epoch_idx": epoch_idx,
300-
"rank": rank,
301-
"eval_result_dict": metric_values,
302-
}
303-
queue.put(("eval_result", msg_content))
270+
queue.put(("eval_result", msg_content))
304271

305-
logger.info(
306-
f"[rank {rank}] finished, sleep for 15 seconds before sending finish signal and exit"
272+
prof.export_chrome_trace(
273+
f"/home/lizhouyu/tmp/trace_noremap_{args.zch_method if len(args.zch_method) > 0 else 'nonzch'}_{args.model_name}_fullloop_tbsize_{args.num_embeddings}_rank{rank}.json"
307274
)
308-
time.sleep(15)
275+
time.sleep(10)
309276
queue.put(("finished", {"rank": rank}))
310277
print("finished")
311278
return
@@ -809,8 +776,6 @@ def statistic(args: argparse.Namespace, queue: multiprocessing.Queue) -> None:
809776
if __name__ == "__main__":
810777
args: argparse.Namespace = parse_args(sys.argv[1:])
811778

812-
__builtins__.__dict__["profile"] = LineProfiler()
813-
814779
# set environment variables
815780
os.environ["MASTER_ADDR"] = str("localhost")
816781
os.environ["MASTER_PORT"] = str(get_free_port())
Binary file not shown.

torchrec/distributed/benchmark/benchmark_zch/data/configs/criteo_kaggle.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
1+
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed_small"
22
batch_size: 4096
33
seed: 0
44
multitask_configs:

torchrec/distributed/benchmark/benchmark_zch/data/configs/movielens_1m.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
dataset_path: "/home/lizhouyu/oss_github/generative-recommenders/tmp/data/ml-1m"
22
batch_size: 16
3-
train_split_percentage: 0.75
3+
train_split_percentage: 0.8
44
num_workers: 4
55
prefetch_factor: 4
66
max_num_candidates: 10

0 commit comments

Comments
 (0)