
Commit f7dc7d1

Fix cuda time (#337)
1 parent a404ea7 commit f7dc7d1

File tree

1 file changed: +10 -16 lines changed

tritonbench/components/do_bench/cuda_time.py

Lines changed: 10 additions & 16 deletions
@@ -12,7 +12,7 @@
 CACHE_CLEAR_KERNEL = "void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, std::array<char*, 1ul> >(int, at::native::FillFunctor<int>, std::array<char*, 1ul>)"
 
 
-def _kineto_events_to_latency(prof):
+def _kineto_events_to_latency(prof, n_repeat):
     prof_averages = prof.key_averages(group_by_input_shape=False)
     cuda_event_names = [
         event.key
@@ -33,22 +33,16 @@ def _kineto_events_to_latency(prof):
             kernel_duration_name_map[event.name()] = []
         kernel_duration_name_map[event.name()].append(event.duration_ns() / 1e6)
 
-    kernel_hits = [len(kernel_duration_name_map[k]) for k in kernel_duration_name_map]
-    assert all(
-        x == kernel_hits[0] for x in kernel_hits
-    ), "Error: Not all kernels run the same time."
+    op_time = 0.0
+    for name in kernel_duration_name_map:
+        op_time += sum(kernel_duration_name_map[name])
 
-    op_latencies = []
-    for x in range(kernel_hits[0]):
-        op_time = 0.0
-        for name in kernel_duration_name_map:
-            op_time += kernel_duration_name_map[name][x]
-        op_latencies.append(op_time)
+    op_time = op_time / n_repeat
 
     print(
         prof.key_averages(group_by_input_shape=False).table(sort_by="cuda_time_total")
     )
-    return Latency(times=op_latencies)
+    return op_time
 
 
 def _do_bench_cuda_time_cudagraph(
@@ -59,7 +53,7 @@ def _do_bench_cuda_time_cudagraph(
     n_repeat: int,
     grad_to_none: bool,
     bypass_fail: bool = False,
-) -> Latency:
+) -> float:
     with torch.cuda.stream(torch.cuda.Stream()):
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
@@ -87,7 +81,7 @@ def _do_bench_cuda_time_cudagraph(
             prof.step()
     synchronize_with_timing()
 
-    return _kineto_events_to_latency(prof)
+    return _kineto_events_to_latency(prof, n_repeat)
 
 
 def do_bench_cuda_time(
@@ -97,7 +91,7 @@ def do_bench_cuda_time(
     grad_to_none: bool,
     use_cuda_graphs: bool = False,
     bypass_fail: bool = False,
-) -> Latency:
+) -> float:
     """
     Return the aggregated CUDA time of a benchmarked operator backend.
     """
@@ -156,4 +150,4 @@ def synchronize_with_timing():
             prof.step()
     synchronize_with_timing()
 
-    return _kineto_events_to_latency(prof)
+    return _kineto_events_to_latency(prof, n_repeat)
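
For reference, the aggregation this commit introduces in _kineto_events_to_latency boils down to summing every recorded kernel duration and dividing by the number of profiled repeats, so the function now returns a single averaged CUDA time in milliseconds instead of a Latency object. A minimal sketch of that logic follows; the helper name and the sample durations are illustrative only, and the real function additionally filters Kineto events and prints the profiler table:

# Minimal sketch of the new aggregation, assuming kernel durations (in ms)
# have already been collected into a dict of kernel name -> list of samples.
def aggregate_cuda_time_ms(kernel_duration_name_map, n_repeat):
    # Sum every recorded duration across all kernels and all repeats ...
    op_time = 0.0
    for name in kernel_duration_name_map:
        op_time += sum(kernel_duration_name_map[name])
    # ... then average over the number of profiled repeats.
    return op_time / n_repeat

# Example: two kernels, each profiled over 2 repeats.
durations = {"kernel_a": [1.0, 1.0], "kernel_b": [3.0, 3.0]}
print(aggregate_cuda_time_ms(durations, n_repeat=2))  # -> 4.0 (average ms per repeat)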

0 commit comments
