
Check QPS Regressions of MPZCH #3205


Open · wants to merge 1 commit into base: main
291 changes: 120 additions & 171 deletions torchrec/distributed/benchmark/benchmark_zch/benchmark_zch.py
@@ -9,22 +9,22 @@
import argparse
import csv
import json
import logging
import multiprocessing
import os
import sys
import time

from typing import cast, Dict, Iterator, List, Optional
from typing import Dict, List, Optional

import numpy as np

import torch
import torch.nn as nn

from line_profiler import LineProfiler

from torch import distributed as dist

# pyre-ignore [21] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
from torch.profiler import profile, ProfilerActivity, record_function
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter # @manual //caffe2:torch_tensorboard
from torchrec.metrics.metrics_namespace import MetricPrefix
@@ -38,7 +38,6 @@
from .benchmark_zch_utils import BenchmarkMCProbe, get_logger, get_module_from_instance

from .data.get_dataloader import get_dataloader
from .data.get_metric_modules import get_metric_modules
from .data.nonzch_remapper import NonZchModRemapperModule
from .models.apply_optimizers import (
apply_dense_optimizers,
@@ -80,13 +79,6 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
# get training dataset
logger.info(f"[rank {rank}] get train dataloader")
train_dataloader = get_dataloader(args.dataset_name, args, "train")
# get test dataset
logger.info(f"[rank {rank}] get test dataloader")
test_dataloader = get_dataloader(args.dataset_name, args, "val")

# get metric modules
logger.info(f"[rank {rank}] get metric modules")
metric_modules = get_metric_modules(rank, args, device)

# make the model
logger.info(f"[rank {rank}] make model")
@@ -146,166 +138,125 @@ def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> N
total_num_queries_in_training = 0

# train the model
logger.info(f"[rank {rank}] train the model")
batch_cnt = 0
for epoch_idx in range(args.epochs):
model.train()
starter_list = []
ender_list = []
num_queries_per_batch_list = []
loss_per_batch_list = []
pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
for batch_idx, batch in enumerate(pbar):
# batch = batch.to(device)
batch = batch.to(device)
# remap the batch if needed
if len(args.zch_method) == 0:
# pyre-ignore [16] # NOTE: pyre reports nonzch_remapper can be None, but when reach to this branch of condition, we know it is not None
batch = nonzch_remapper.remap(batch)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)
if True or len(args.zch_method) > 0:
benchmark_probe.record_mcec_state(stage="before_fwd")
# train model
starter.record()
## zero the gradients
optimizer.zero_grad()
## forward pass
loss, (loss_values, pred_logits, labels, weights) = model(batch)
## backward pass
loss.backward()
## update weights
optimizer.step()
ender.record()
# update the batch counter
batch_cnt += 1
# append the start and end events to the lists
starter_list.append(starter)
ender_list.append(ender)
# do training metrics and QPS statistics
num_queries_per_batch = len(labels)
num_queries_per_batch_list.append(num_queries_per_batch)
loss_per_batch_list.append(loss.cpu().item())
# do zch statistics
benchmark_probe.record_mcec_state(stage="after_fwd")
# update zch statistics
benchmark_probe.update()
# push the zch stats to the queue
msg_content = {
"epoch_idx": epoch_idx,
"batch_idx": batch_idx,
"batch_cnt": batch_cnt,
"rank": rank,
"mch_stats": benchmark_probe.get_mch_stats(),
}
queue.put(
("mch_stats", msg_content),
)
if (
batch_idx % interval_num_batches_show_qps == 0
or batch_idx == len(train_dataloader) - 1
):
if batch_idx == 0:
# skip the first batch since it is not a full batch
continue
logger.info(f"[rank {rank}] batch_idx: {batch_idx} get the stats")
# synchronize all the threads to get the exact number of batches
torch.cuda.synchronize()
# calculate the qps
# NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
# because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
# and this is a heavy operation (takes several milliseconds). So we only do this calculation every
# interval_num_batches_show_qps batches to reduce the overhead.
## get per batch time list by calculating the time difference between the start and end events of each batch
per_batch_time_list = []
for i in range(len(starter_list)):
per_batch_time_list.append(
starter_list[i].elapsed_time(ender_list[i]) / 1000
) # convert to seconds by dividing by 1000
## calculate the total time in the interval
total_time_in_interval = sum(per_batch_time_list)
## calculate the total number of queries in the interval
total_num_queries_in_interval = sum(num_queries_per_batch_list)
## assemble the message with total_num_queries_in_interval for the queue
interval_start_batch_idx = (
batch_idx - interval_num_batches_show_qps
if batch_idx >= interval_num_batches_show_qps
else 0
) # the start batch index of the interval
interval_start_batch_cnt = (
batch_cnt - interval_num_batches_show_qps
if batch_cnt >= interval_num_batches_show_qps
else 0
) # the start batch counter of the interval
interval_end_batch_idx = (
batch_idx # the end batch index of the interval
)
## fabricate the message content
msg_content = {
"epoch_idx": epoch_idx,
"rank": rank,
"interval_start_batch_idx": interval_start_batch_idx,
"interval_end_batch_idx": interval_end_batch_idx,
"interval_start_batch_cnt": interval_start_batch_cnt,
"interval_end_batch_cnt": batch_cnt,
"per_batch_time_list": per_batch_time_list,
"per_batch_num_queries_list": num_queries_per_batch_list,
}
## put the message into the queue
queue.put(("duration_and_num_queries", msg_content))
## also fabricate the message for loss
msg_content = {
"epoch_idx": epoch_idx,
"rank": rank,
"interval_start_batch_idx": interval_start_batch_idx,
"interval_end_batch_idx": interval_end_batch_idx,
"interval_start_batch_cnt": interval_start_batch_cnt,
"interval_end_batch_cnt": batch_cnt,
"per_batch_loss_list": loss_per_batch_list,
}
## put the message into the queue
queue.put(("training_metrics", msg_content))
# calculate QPS per statistic interval
qps_per_interval = (
total_num_queries_in_interval / total_time_in_interval
## code for profiling
# pyre-ignore [16] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
device = torch.device("cuda")
# pyre-ignore [16] # NOTE: pyre reports ProfilerActivity is not in torchrec.distributed, but it is in torch.profiler according to https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
activities += [ProfilerActivity.CUDA]
## end code for profiling
with profile(activities=activities, record_shapes=True) as prof:
for epoch_idx in range(args.epochs):
model.train()
starter_list = []
ender_list = []
num_queries_per_batch_list = []
pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}")
for batch_idx, batch in enumerate(pbar):
# batch = batch.to(device)
batch = batch.to(device)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)
total_time_in_training += total_time_in_interval
total_num_queries_in_training += total_num_queries_in_interval
pbar.set_postfix(
{
"QPS": qps_per_interval,
if len(args.zch_method) > 0:
benchmark_probe.record_mcec_state(stage="before_fwd")
# forward pass
with record_function(f"model training on batch {batch_idx}"):
starter.record()
# zero the gradients
optimizer.zero_grad()
loss, (loss_values, pred_logits, labels) = model(batch)
loss.backward()
optimizer.step()
ender.record()
# do statistics
num_queries_per_batch = len(labels)
starter_list.append(starter)
ender_list.append(ender)
num_queries_per_batch_list.append(num_queries_per_batch)
if len(args.zch_method) > 0:
benchmark_probe.record_mcec_state(stage="after_fwd")
# update zch statistics
benchmark_probe.update()
# push the zch stats to the queue
msg_content = {
"epoch_idx": epoch_idx,
"batch_idx": batch_idx,
"rank": rank,
"mch_stats": benchmark_probe.get_mch_stats(),
}
)
pbar.update(interval_num_batches_show_qps)
# reset the lists
starter_list = []
ender_list = []
num_queries_per_batch_list = []
loss_per_batch_list = []
# after training of each epoch, do validation
logger.info(f"[rank {rank}] do validation after training of epoch {epoch_idx}")
metric_values = evaluation(
metric_modules,
model,
test_dataloader,
device,
nonzch_remapper if len(args.zch_method) == 0 else None,
)
# print the evaluation result
print(f"Evaluation result: {metric_values}")
# send the evaluation result to the queue
msg_content = {
"epoch_idx": epoch_idx,
"rank": rank,
"eval_result_dict": metric_values,
}
queue.put(("eval_result", msg_content))

logger.info(
f"[rank {rank}] finished, sleep for 15 seconds before sending finish signal and exit"
queue.put(
("mch_stats", msg_content),
)
if (
batch_idx % interval_num_batches_show_qps == 0
or batch_idx == len(train_dataloader) - 1
):
if batch_idx == 0:
# skip the first batch since it is not a full batch
continue
# synchronize all the threads to get the exact number of batches
torch.cuda.synchronize()
# calculate the qps
# NOTE: why do this qps calculation every interval_num_batches_show_qps batches?
# because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize()
# and this is a heavy operation (takes several milliseconds). So we only do this calculation every
# interval_num_batches_show_qps batches to reduce the overhead.
## get per batch time list by calculating the time difference between the start and end events of each batch
per_batch_time_list = []
for i in range(len(starter_list)):
per_batch_time_list.append(
starter_list[i].elapsed_time(ender_list[i]) / 1000
) # convert to seconds by dividing by 1000
## calculate the total time in the interval
total_time_in_interval = sum(per_batch_time_list)
## calculate the total number of queries in the interval
total_num_queries_in_interval = sum(num_queries_per_batch_list)
## assemble the message with total_num_queries_in_interval for the queue
interval_start_batch_idx = (
batch_idx - interval_num_batches_show_qps
if batch_idx >= interval_num_batches_show_qps
else 0
) # the start batch index of the interval
interval_end_batch_idx = (
batch_idx # the end batch index of the interval
)
## fabricate the message content
msg_content = {
"epoch_idx": epoch_idx,
"rank": rank,
"interval_start_batch_idx": interval_start_batch_idx,
"interval_end_batch_idx": interval_end_batch_idx,
"per_batch_time_list": per_batch_time_list,
"per_batch_num_queries_list": num_queries_per_batch_list,
}
## put the message into the queue
queue.put(("duration_and_num_queries", msg_content))
qps_per_interval = (
total_num_queries_in_interval / total_time_in_interval
)
total_time_in_training += total_time_in_interval
total_num_queries_in_training += total_num_queries_in_interval
pbar.set_postfix(
{
"QPS": qps_per_interval,
}
)
pbar.update(interval_num_batches_show_qps)
# reset the lists
starter_list = []
ender_list = []
num_queries_per_batch_list = []
if batch_idx > 50:
# skip rest after collecting data for 50 batches
break
# skip evaluation

prof.export_chrome_trace(
f"/home/lizhouyu/tmp/trace_noremap_{args.zch_method if len(args.zch_method) > 0 else 'nonzch'}_{args.model_name}_fullloop_tbsize_{args.num_embeddings}_rank{rank}.json"
)
time.sleep(15)
time.sleep(10)
queue.put(("finished", {"rank": rank}))
print("finished")
return
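
The training loop above times each batch with a pair of torch.cuda.Event markers and only calls torch.cuda.synchronize() once every interval_num_batches_show_qps batches before turning the recorded events into a QPS number. A minimal standalone sketch of that measurement pattern, assuming a generic model that returns a scalar loss; the helper name and INTERVAL constant are illustrative, not from this PR:

```python
import torch

INTERVAL = 10  # stand-in for interval_num_batches_show_qps


def measure_qps(model, optimizer, dataloader, device):
    # Assumption for this sketch: model(inputs, labels) returns a scalar loss.
    starters, enders, batch_sizes = [], [], []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()
        optimizer.zero_grad()
        loss = model(inputs, labels)
        loss.backward()
        optimizer.step()
        ender.record()
        starters.append(starter)
        enders.append(ender)
        batch_sizes.append(len(labels))
        if batch_idx > 0 and batch_idx % INTERVAL == 0:
            # One synchronize per interval: elapsed_time() needs completed
            # events, and synchronizing every batch would dominate the cost.
            torch.cuda.synchronize()
            seconds = [s.elapsed_time(e) / 1000 for s, e in zip(starters, enders)]
            qps = sum(batch_sizes) / sum(seconds)
            print(f"batches up to {batch_idx}: {qps:.1f} queries/sec")
            starters, enders, batch_sizes = [], [], []
```

Synchronizing once per interval rather than per batch keeps the blocking call out of the hot path, which is the trade-off the NOTE in the loop describes.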
@@ -809,8 +760,6 @@ def statistic(args: argparse.Namespace, queue: multiprocessing.Queue) -> None:
if __name__ == "__main__":
args: argparse.Namespace = parse_args(sys.argv[1:])

__builtins__.__dict__["profile"] = LineProfiler()

# set environment variables
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
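
The training ranks report their measurements through the multiprocessing queue passed into statistic(); with this change the remaining message types are "mch_stats", "duration_and_num_queries", and "finished", while "training_metrics" and "eval_result" are no longer produced. A hedged sketch of a consumer for those messages follows; the actual aggregation inside statistic() is not shown in this diff, so the details below are assumptions:

```python
import multiprocessing


def consume_stats(queue: multiprocessing.Queue, world_size: int) -> None:
    # Aggregate the per-interval timing messages sent by the training ranks;
    # exit once every rank has sent its "finished" message.
    finished_ranks = set()
    total_time_sec, total_queries = 0.0, 0
    while len(finished_ranks) < world_size:
        msg_type, content = queue.get()
        if msg_type == "duration_and_num_queries":
            total_time_sec += sum(content["per_batch_time_list"])
            total_queries += sum(content["per_batch_num_queries_list"])
        elif msg_type == "mch_stats":
            # Per-batch ZCH remapping statistics; logging omitted in this sketch.
            pass
        elif msg_type == "finished":
            finished_ranks.add(content["rank"])
    if total_time_sec > 0:
        print(f"overall QPS across ranks: {total_queries / total_time_sec:.1f}")
```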
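
The new training loop is also wrapped in torch.profiler.profile, with each batch labeled via record_function and the result written out per rank with export_chrome_trace. A minimal, self-contained version of that pattern; the workload and output path are placeholders:

```python
import torch
from torch.profiler import profile, ProfilerActivity, record_function

# Collect CPU activity, and CUDA activity when a GPU is available.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for batch_idx in range(5):  # placeholder workload for the sketch
        with record_function(f"model training on batch {batch_idx}"):
            torch.randn(1024, 1024).matmul(torch.randn(1024, 1024))

# The trace opens in chrome://tracing or Perfetto; the path is a placeholder.
prof.export_chrome_trace("/tmp/trace_rank0.json")
```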
Binary file not shown.
@@ -1,4 +1,4 @@
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed_small"
batch_size: 4096
seed: 0
multitask_configs:
@@ -1,6 +1,6 @@
dataset_path: "/home/lizhouyu/oss_github/generative-recommenders/tmp/data/ml-1m"
batch_size: 16
train_split_percentage: 0.75
train_split_percentage: 0.8
num_workers: 4
prefetch_factor: 4
max_num_candidates: 10