From 999dbfcbe81cf43b870fdf6a73aaf5b68ab60c22 Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Thu, 17 Jul 2025 11:46:45 -0700
Subject: [PATCH 1/2] Refactor benchmark utilities: centralize core logic into `_run_benchmark_core` (#3191)

Summary:
Removed redundant logic in `benchmark` and `benchmark_func` by moving shared timing, memory, and profiling code into a new `_run_benchmark_core` function. Ensured backward compatibility.

Differential Revision: D78290979
---
 .../distributed/benchmark/benchmark_utils.py | 372 ++++++++----------
 1 file changed, 166 insertions(+), 206 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index b954dd0c1..0cb2147b8 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -655,62 +655,76 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 return sharded_module if not benchmark_unsharded_module else module


-def benchmark(
+def _run_benchmark_core(
 name: str,
- model: torch.nn.Module,
- warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
- bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
- prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
+ run_iter_fn: Callable[[], None],
+ profile_iter_fn: Optional[Callable[[Any], None]], # pyre-ignore [2]
 world_size: int,
- output_dir: str,
- num_benchmarks: int,
- # pyre-ignore[2]
- func_to_benchmark: Any,
- benchmark_func_kwargs: Optional[Dict[str, Any]],
 rank: int,
- enable_logging: bool = True,
- device_type: str = "cuda",
- benchmark_unsharded_module: bool = False,
+ num_benchmarks: int,
+ device_type: str,
+ output_dir: str,
+ pre_gpu_load: int = 0,
+ export_stacks: bool = False,
+ reset_accumulated_memory_stats: bool = False,
 ) -> BenchmarkResult:
- memory_stats: List[MemoryStats] = []
- if enable_logging:
- logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}")
+ """Internal helper that contains the core benchmarking logic shared by
+ ``benchmark`` and ``benchmark_func``. All heavy lifting (timing, memory
+ accounting, optional profiling) happens here so the public helpers can stay
+ small and focused on preparing the callables to execute.

- for _input in warmup_inputs:
- model(_input)
+ Args:
+ name: Human-readable benchmark name.
+ run_iter_fn: Zero-arg callable that executes one measured iteration.
+ profile_iter_fn: Optional callable that receives a ``torch.profiler``
+ instance and runs the iterations that should be captured.
+ world_size, rank: Distributed context to correctly reset / collect GPU
+ stats. ``rank == -1`` means single-process mode.
+ num_benchmarks: Number of measured iterations.
+ device_type: "cuda" or "cpu".
+ output_dir: Where to write chrome traces / stack files.
+ pre_gpu_load: Number of dummy matmul operations to run before the first
+ measured iteration (helps simulate a loaded allocator).
+ export_stacks: Whether to export flamegraph-compatible stack files.
+ reset_accumulated_memory_stats: Whether to reset accumulated memory
+ stats in addition to peak memory stats.
+ """ + # Preparation & memory reset if device_type == "cuda": if rank == -1: - # Reset memory for measurement, no process per rank so do all for di in range(world_size): torch.cuda.reset_peak_memory_stats(di) + if reset_accumulated_memory_stats: + torch.cuda.reset_accumulated_memory_stats(di) else: torch.cuda.reset_peak_memory_stats(rank) + if reset_accumulated_memory_stats: + torch.cuda.reset_accumulated_memory_stats(rank) - start = [] - end = [] - if device_type == "cuda": - # Measure time taken for batches in bench_inputs - start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - - if benchmark_func_kwargs is None: - # Need this to unwrap - benchmark_func_kwargs = {} + # Optional allocator warm-up to create fragmentation similar to production + if pre_gpu_load and device_type == "cuda": + _tmp = torch.rand(16384, 16384, device="cuda") + for _ in range(pre_gpu_load): + _tmp = _tmp * torch.rand(16384, 16384, device="cuda") - times = [] + # Timing loop + start_events, end_events, times = [], [], [] if device_type == "cuda": + start_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks) + ] + end_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks) + ] for i in range(num_benchmarks): - start[i].record() - func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs) - end[i].record() - elif device_type == "cpu": - times = timeit.repeat( - lambda: func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs), - number=1, - repeat=num_benchmarks, - ) + start_events[i].record() + run_iter_fn() + end_events[i].record() + else: + times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks) + # Make sure all kernels are finished before reading timers / stats if device_type == "cuda": if rank == -1: for di in range(world_size): @@ -718,79 +732,118 @@ def benchmark( else: torch.cuda.synchronize(rank) - # TODO: First Benchmark Run for Eager Mode produces outlier + # First Benchmark Run for Eager Mode produces outlier # Start counting after first as workaround for standard deviation if device_type == "cuda": elapsed_time = torch.tensor( - [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] + [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])] ) else: - elapsed_time = torch.tensor(times) * 1e3 + elapsed_time = torch.tensor(times) * 1e3 # convert seconds ➜ milliseconds + # Memory statistics collection + mem_stats: List[MemoryStats] = [] if device_type == "cuda": if rank == -1: - # Add up all memory allocated in inference mode for di in range(world_size): - memory_stats.append(MemoryStats.for_device(di)) + mem_stats.append(MemoryStats.for_device(di)) else: - # Only add up memory allocated for current rank in training mode - memory_stats.append(MemoryStats.for_device(rank)) + mem_stats.append(MemoryStats.for_device(rank)) - if output_dir != "": - # Only do profiling if output_dir is set + # Optional detailed profiling + if output_dir and profile_iter_fn and device_type == "cuda": - # pyre-ignore[2] - def trace_handler(prof) -> None: - total_average = prof.profiler.total_average() - logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") - dir_path: str = output_dir - - # only 1 rank should output in pg case, rank = 0 + def _trace_handler(prof: torch.profiler.profile) -> None: + total_avg = prof.profiler.total_average() + logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}") if rank > 0: return - - trace_file: str = 
f"{dir_path}/trace-{name}.json" - stacks_cpu_file = f"{dir_path}/stacks-cpu-{name}.stacks" - stacks_cuda_file = f"{dir_path}/stacks-cuda-{name}.stacks" + trace_file = f"{output_dir}/trace-{name}.json" logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") - prof.export_chrome_trace(trace_file) - prof.export_stacks(stacks_cpu_file, "self_cpu_time_total") - prof.export_stacks(stacks_cuda_file, "self_cuda_time_total") - - # - git clone https://github.com/brendangregg/FlameGraph - # - cd FlameGraph - # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg - - if device_type == "cuda": - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - profile_memory=True, - with_stack=True, - with_flops=True, - with_modules=True, - on_trace_ready=trace_handler, - ) as p: - for _input in prof_inputs: - with record_function("## forward ##"): - model(_input) - p.step() - - if rank == -1: - for di in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{di}")) - else: - torch.cuda.synchronize() + if export_stacks: + prof.export_stacks( + f"{output_dir}/stacks-cpu-{name}.stacks", "self_cpu_time_total" + ) + prof.export_stacks( + f"{output_dir}/stacks-cuda-{name}.stacks", "self_cuda_time_total" + ) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_flops=True, + with_modules=True, + with_stack=export_stacks, + on_trace_ready=_trace_handler, + ) as prof: + profile_iter_fn(prof) + + # Synchronize again after profiling to guarantee deterministic ordering + if rank == -1: + for di in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{di}")) + else: + torch.cuda.synchronize(rank) return BenchmarkResult( - short_name=name, - elapsed_time=elapsed_time, - mem_stats=memory_stats, + short_name=name, elapsed_time=elapsed_time, mem_stats=mem_stats, rank=rank + ) + + +def benchmark( + name: str, + model: torch.nn.Module, + warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + world_size: int, + output_dir: str, + num_benchmarks: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int, + enable_logging: bool = True, + device_type: str = "cuda", + benchmark_unsharded_module: bool = False, +) -> BenchmarkResult: + if enable_logging: + logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}") + + # Warm-up forwards to stabilize kernels / JIT compilation + for _input in warmup_inputs: + model(_input) + + if benchmark_func_kwargs is None: + benchmark_func_kwargs = {} + + run_iter_fn: Callable[[], None] = lambda: func_to_benchmark( + model, bench_inputs, **benchmark_func_kwargs + ) + + def _profile_iter_fn(prof: torch.profiler.profile) -> None: + for _input in prof_inputs: + with record_function("## forward ##"): + model(_input) + prof.step() + + return _run_benchmark_core( + name=name, + run_iter_fn=run_iter_fn, + profile_iter_fn=_profile_iter_fn if output_dir else None, + world_size=world_size, rank=rank, + num_benchmarks=num_benchmarks, + device_type=device_type, + output_dir=output_dir, + pre_gpu_load=0, + export_stacks=True, + reset_accumulated_memory_stats=False, ) @@ -809,124 +862,31 @@ def benchmark_func( 
device_type: str = "cuda", pre_gpu_load: int = 0, ) -> BenchmarkResult: - memory_stats: List[MemoryStats] = [] - if device_type == "cuda": - if rank == -1: - # Reset memory for measurement, no process per rank so do all - for di in range(world_size): - torch.cuda.reset_peak_memory_stats(di) - torch.cuda.reset_accumulated_memory_stats(di) - else: - torch.cuda.reset_peak_memory_stats(rank) - torch.cuda.reset_accumulated_memory_stats(rank) - - start = [] - end = [] - if device_type == "cuda": - # Measure time taken for batches in bench_inputs - start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - if benchmark_func_kwargs is None: - # Need this to unwrap benchmark_func_kwargs = {} - times = [] - if device_type == "cuda": - a = torch.rand(16384, 16384, device="cuda") - for _ in range(pre_gpu_load): - a = a * torch.rand(16384, 16384, device="cuda") - for i in range(num_benchmarks): - start[i].record() - func_to_benchmark(bench_inputs, **benchmark_func_kwargs) - end[i].record() - elif device_type == "cpu": - if bench_inputs is None or len(bench_inputs) == 0: - times = timeit.repeat( - lambda: func_to_benchmark(**benchmark_func_kwargs), - number=1, - repeat=num_benchmarks, - ) - else: - times = timeit.repeat( - lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs), - number=1, - repeat=num_benchmarks, - ) - - if device_type == "cuda": - if rank == -1: - for di in range(world_size): - torch.cuda.synchronize(di) - else: - torch.cuda.synchronize(rank) - - # TODO: First Benchmark Run for Eager Mode produces outlier - # Start counting after first as workaround for standard deviation - if device_type == "cuda": - elapsed_time = torch.tensor( - [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] - ) - else: - elapsed_time = torch.tensor(times) * 1e3 - - if device_type == "cuda": - if rank == -1: - # Add up all memory allocated in inference mode - for di in range(world_size): - memory_stats.append(MemoryStats.for_device(di)) - else: - # Only add up memory allocated for current rank in training mode - memory_stats.append(MemoryStats.for_device(rank)) - - if profile_dir != "": - # Only do profiling if output_dir is set - - # pyre-ignore[2] - def trace_handler(prof) -> None: - total_average = prof.profiler.total_average() - logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") - dir_path: str = profile_dir - if rank == 0: - trace_file: str = f"{dir_path}/trace-{name}.json" - else: - trace_file: str = f"{dir_path}/trace-{name}-{rank}.json" - return # only 1 rank should output in pg case, rank = 0 - logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") - prof.export_chrome_trace(trace_file) - - if device_type == "cuda": - a = torch.rand(16384, 16384, device="cuda") - for _ in range(pre_gpu_load): - a = a * torch.rand(16384, 16384, device="cuda") - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - profile_memory=True, - with_flops=True, - with_modules=True, - with_stack=False, # usually we don't want to show the entire stack in the trace - on_trace_ready=trace_handler, - ) as p: - for i in range(num_profiles): - with record_function(f"## profile {i} ##"): - func_to_benchmark(prof_inputs, **benchmark_func_kwargs) - p.step() - - if rank == -1: - for di in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{di}")) - else: - torch.cuda.synchronize() + 
run_iter_fn: Callable[[], None] = lambda: func_to_benchmark(
 bench_inputs, **benchmark_func_kwargs
 )

- return BenchmarkResult(
- short_name=name,
- elapsed_time=elapsed_time,
- mem_stats=memory_stats,
+ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
+ for i in range(num_profiles):
+ with record_function(f"## profile {i} ##"):
+ func_to_benchmark(prof_inputs, **benchmark_func_kwargs)
+ prof.step()
+
+ return _run_benchmark_core(
+ name=name,
+ run_iter_fn=run_iter_fn,
+ profile_iter_fn=_profile_iter_fn if profile_dir else None,
+ world_size=world_size, rank=rank,
+ num_benchmarks=num_benchmarks,
+ device_type=device_type,
+ output_dir=profile_dir,
+ pre_gpu_load=pre_gpu_load,
+ export_stacks=False,
+ reset_accumulated_memory_stats=True,
 )

From ab2d68a0d857c007b8d82d71264ab6a88e9ab1ad Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Thu, 17 Jul 2025 11:46:45 -0700
Subject: [PATCH 2/2] Added CPU runtime benchmarking and expanded BenchmarkResult to store CPU runtimes

Summary:
* Added a feature to benchmark CPU runtimes alongside GPU measurements.
* Expanded the `BenchmarkResult` class to store measurements for both devices.
* Adapted files that import `BenchmarkResult` to ensure compatibility.

Now we can compare CPU and GPU runtimes without running and analyzing the PyTorch Profiler. BenchmarkResult can help users detect whether a module/function/operator is CPU-bound or GPU-bound.

Example metrics of FBGEMM operators:

| Operator | CPU Runtime | GPU Runtime | GPU Memory |
|---------------------------------------------|-------------|-------------|------------|
| **[fallback] pytorch generic** | 5.41 ms | 2.66 ms | 1.01 GB |
| **[Prod] KeyedTensor.regroup_dup** | 2.13 ms | 2.48 ms | 1.01 GB |
| **[Module] KTRegroupAsDict_dup** | 0.14 ms | 0.75 ms | 1.01 GB |
| **[2 Ops] permute_multi_embs_dup** | 0.88 ms | 1.44 ms | 1.01 GB |
| **[1 Op] KT_regroup_dup** | 0.99 ms | 1.54 ms | 1.01 GB |

We can see that `[fallback] pytorch generic` is CPU-bound.

Differential Revision: D78503319
---
 .../distributed/benchmark/benchmark_utils.py | 114 +++++++++++++-----
 .../sparse/tests/jagged_tensor_benchmark.py | 7 +-
 .../keyed_jagged_tensor_benchmark_lib.py | 3 +-
 3 files changed, 87 insertions(+), 37 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 0cb2147b8..607cece6a 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -134,26 +134,43 @@ def __str__(self) -> str:
 class BenchmarkResult:
 "Class for holding results of benchmark runs"
 short_name: str
- elapsed_time: torch.Tensor # milliseconds
+ gpu_elapsed_time: torch.Tensor # milliseconds
+ cpu_elapsed_time: torch.Tensor # milliseconds
 mem_stats: List[MemoryStats] # memory stats per rank
 rank: int = -1

 def __str__(self) -> str:
- runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
+ gpu_runtime = (
+ f"GPU Runtime (P90): {self.runtime_percentile(90, device='gpu'):.2f} ms"
+ )
+ cpu_runtime = (
+ f"CPU Runtime (P90): {self.runtime_percentile(90, device='cpu'):.2f} ms"
+ )
 if len(self.mem_stats) == 0:
- return f"{self.short_name: <{35}} | {runtime}"
+ return f"{self.short_name: <{35}} | {gpu_runtime} | {cpu_runtime}"
 mem_alloc = (
 f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
 )
 mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB"
- malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / 
{self.mem_retries(90)} / {self.mem_retries(100)}" - return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}" + malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50)} / {self.mem_retries(90)} / {self.mem_retries(100)}" + return f"{self.short_name: <{35}} | {malloc_retries} | {gpu_runtime} | {cpu_runtime} | {mem_alloc} | {mem_reserved}" def runtime_percentile( - self, percentile: int = 50, interpolation: str = "nearest" + self, + percentile: int = 50, + interpolation: str = "nearest", + device: str = "gpu", ) -> torch.Tensor: + """Return the runtime percentile for the requested timer. + + Args: + percentile: Percentile to compute. + interpolation: See ``torch.quantile``. + device: 'gpu' for CUDA event timings, 'cpu' for active CPU timings. + """ + timings = self.gpu_elapsed_time if device == "gpu" else self.cpu_elapsed_time return torch.quantile( - self.elapsed_time, + timings, percentile / 100.0, interpolation=interpolation, ) @@ -408,17 +425,26 @@ def write_report( num_requests: int, ) -> None: for benchmark_res in benchmark_results: - avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3 # time in seconds - std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3 # time in seconds + # GPU statistics + avg_dur_s_gpu = benchmark_res.gpu_elapsed_time.mean().item() * 1e-3 # sec + std_dur_s_gpu = benchmark_res.gpu_elapsed_time.std().item() * 1e-3 # sec + + # CPU statistics + avg_dur_s_cpu = benchmark_res.cpu_elapsed_time.mean().item() * 1e-3 # sec + std_dur_s_cpu = benchmark_res.cpu_elapsed_time.std().item() * 1e-3 # sec - qps = int(num_requests / avg_dur_s) + qps_gpu = int(num_requests / avg_dur_s_gpu) mem_str = "" for memory_stats in benchmark_res.mem_stats: mem_str += f"{memory_stats}\n" - report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}" - report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n" + report_str += ( + f"{benchmark_res.short_name:40} " + f"Avg QPS(GPU):{qps_gpu:10} " + f"GPU Avg: {int(1000*avg_dur_s_gpu):5}ms ±{(1000*std_dur_s_gpu):.2f}ms " + f"CPU Avg: {int(1000*avg_dur_s_cpu):5}ms ±{(1000*std_dur_s_cpu):.2f}ms\n" + ) report_str += f"\tMemory Allocated Per Rank:\n\t{mem_str}\n" with open(report_file, "w") as f: @@ -702,14 +728,15 @@ def _run_benchmark_core( if reset_accumulated_memory_stats: torch.cuda.reset_accumulated_memory_stats(rank) - # Optional allocator warm-up to create fragmentation similar to production - if pre_gpu_load and device_type == "cuda": - _tmp = torch.rand(16384, 16384, device="cuda") - for _ in range(pre_gpu_load): - _tmp = _tmp * torch.rand(16384, 16384, device="cuda") + # Optional allocator warm-up to create fragmentation similar to production + if pre_gpu_load: + _tmp = torch.rand(16384, 16384, device="cuda") + for _ in range(pre_gpu_load): + _tmp = _tmp * torch.rand(16384, 16384, device="cuda") - # Timing loop + # Timings start_events, end_events, times = [], [], [] + if device_type == "cuda": start_events = [ torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks) @@ -717,29 +744,47 @@ def _run_benchmark_core( end_events = [ torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks) ] + # Capture per-iteration active CPU cycles (excludes time the thread is truly idle/asleep) using `process_time_ns`. 
+ cpu_times_active_ns: List[int] = [] + for i in range(num_benchmarks): + # Ensure that outstanding GPU work from the previous iteration has + # finished so that we do not attribute its wait time to the next + # CPU measurement. + if i > 0: + torch.cuda.synchronize(rank if rank >= 0 else 0) + start_events[i].record() + cpu_start_active_ns = time.process_time_ns() + run_iter_fn() + + cpu_end_active_ns = time.process_time_ns() end_events[i].record() - else: - times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks) + cpu_times_active_ns.append(cpu_end_active_ns - cpu_start_active_ns) - # Make sure all kernels are finished before reading timers / stats - if device_type == "cuda": + # Convert to milliseconds and drop the first iteration + cpu_elapsed_time = torch.tensor( + [t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float + ) + + # Make sure all kernels are finished before reading timers / stats if rank == -1: for di in range(world_size): torch.cuda.synchronize(di) else: torch.cuda.synchronize(rank) - # First Benchmark Run for Eager Mode produces outlier - # Start counting after first as workaround for standard deviation - if device_type == "cuda": - elapsed_time = torch.tensor( + gpu_elapsed_time = torch.tensor( [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])] ) else: - elapsed_time = torch.tensor(times) * 1e3 # convert seconds ➜ milliseconds + # For CPU-only benchmarks we fall back to wall-clock timing via ``timeit``. + times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks) + cpu_elapsed_time = torch.tensor(times) * 1e3 # convert to ms + + # mirror CPU timings for overall consistency + gpu_elapsed_time = cpu_elapsed_time.clone() # Memory statistics collection mem_stats: List[MemoryStats] = [] @@ -791,7 +836,11 @@ def _trace_handler(prof: torch.profiler.profile) -> None: torch.cuda.synchronize(rank) return BenchmarkResult( - short_name=name, elapsed_time=elapsed_time, mem_stats=mem_stats, rank=rank + short_name=name, + gpu_elapsed_time=gpu_elapsed_time, + cpu_elapsed_time=cpu_elapsed_time, + mem_stats=mem_stats, + rank=rank, ) @@ -1066,10 +1115,11 @@ def setUp() -> None: assert 0 == p.exitcode total_benchmark_res = BenchmarkResult( - benchmark_res_per_rank[0].short_name, - benchmark_res_per_rank[0].elapsed_time, - [MemoryStats(rank, 0, 0, 0) for rank in range(world_size)], - 0, + short_name=benchmark_res_per_rank[0].short_name, + gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time, + cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time, + mem_stats=[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)], + rank=0, ) for res in benchmark_res_per_rank: diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index 34862e380..4e9cd1921 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -107,14 +107,13 @@ def wrapped_func( ) result = BenchmarkResult( short_name=name, - elapsed_time=torch.tensor(times) * 1e3, + gpu_elapsed_time=torch.tensor(times) * 1e3, + cpu_elapsed_time=torch.tensor(times) * 1e3, mem_stats=[MemoryStats(0, 0, 0, 0)], ) - mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}" - mem_reserved = f"Memory alloc (P90): {result.max_mem_reserved_percentile(90):5.1f}" print( - f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}" + 
f"B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | {result}" ) diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index 1c409fcf2..8d5126a13 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -227,7 +227,8 @@ def benchmark_kjt( result = BenchmarkResult( short_name=f"{test_name}-{transform_type.name}", - elapsed_time=torch.tensor(times), + gpu_elapsed_time=torch.tensor(times), + cpu_elapsed_time=torch.tensor(times), mem_stats=[MemoryStats(0, 0, 0, 0)], )