|
6 | 6 | # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
|
7 | 7 |
|
8 | 8 | import itertools
|
9 |
| -import time |
10 | 9 | from dataclasses import dataclass
|
11 | 10 | from typing import List
|
12 | 11 |
|
13 | 12 | import torch
|
14 | 13 | from tabulate import tabulate
|
15 | 14 | from tqdm import tqdm
|
| 15 | +from triton.testing import do_bench |
16 | 16 |
|
17 | 17 | from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
|
18 | 18 | triton_fp8_col_major_jagged_colwise_scales,
|
@@ -129,18 +129,15 @@ def run_triton(
|
129 | 129 |
|
130 | 130 | # bench torch
|
131 | 131 | compiled_run_torch = torch.compile(run_torch)
|
132 |
| - warmup(compiled_run_torch, input_row_major, input_col_major, offs) |
133 |
| - start_time_ns = time.perf_counter_ns() |
134 |
| - compiled_run_torch(input_row_major, input_col_major, offs) |
135 |
| - torch_time_ns = time.perf_counter_ns() - start_time_ns |
136 |
| - torch_time_us = torch_time_ns / 1e3 |
| 132 | + torch_time_us = benchmark_cuda_function_in_microseconds( |
| 133 | + compiled_run_torch, input_row_major, input_col_major, offs |
| 134 | + ) |
137 | 135 |
|
138 | 136 | # bench triton
|
139 | 137 | warmup(run_triton, input_row_major, input_col_major, offs)
|
140 |
| - start_time_ns = time.perf_counter_ns() |
141 |
| - run_triton(input_row_major, input_col_major, offs) |
142 |
| - triton_time_ns = time.perf_counter_ns() - start_time_ns |
143 |
| - triton_time_us = triton_time_ns / 1e3 |
| 138 | + triton_time_us = benchmark_cuda_function_in_microseconds( |
| 139 | + run_triton, input_row_major, input_col_major, offs |
| 140 | + ) |
144 | 141 |
|
145 | 142 | return ExperimentResult(
|
146 | 143 | torch_time_us=torch_time_us,
|
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]):
|
173 | 170 | print(tabulate(rows, headers=headers))
|
174 | 171 |
|
175 | 172 |
|
| 173 | +def benchmark_cuda_function_in_microseconds(f, *args): |
| 174 | + return do_bench(lambda: f(*args), return_mode="median") * 1e3 |
| 175 | + |
| 176 | + |
176 | 177 | def main():
|
177 | 178 | torch.random.manual_seed(123)
|
178 | 179 | configs = get_configs()
|
|
0 commit comments