Commit 3484c61

jananisriram authored and facebook-github-bot committed
Move scaling logic to input generation
Summary: Move the scaling logic for FP8 benchmarks into `get_input_iter()`. This diff aligns our fp8_gemm benchmarking suite with real-world practice: input tensors are generated in high-precision types (`bfloat16`, `float16`), scales are computed on the high-precision inputs, and the inputs are then cast to a lower-precision type (`float8_e4m3fn`). This also avoids performing unsupported operations, such as `torch.max` and `torch.abs`, on low-precision data types.

Reviewed By: NikhilAPatel

Differential Revision: D80571223
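As a rough sketch of the pattern the summary describes (a hypothetical helper, not part of this diff; it assumes a recent PyTorch with float8 support and the `torch._scaled_mm` convention of multiplying the output by the supplied scales), the scale is derived from the high-precision tensor, since `torch.abs` and `torch.max` are unsupported on `float8_e4m3fn`, and only then is the tensor cast down:

    import torch

    def quantize_to_fp8_e4m3(x_hp: torch.Tensor):
        # Hypothetical helper, not part of this commit: derive a tensor-wise
        # scale from the high-precision input, then cast to float8_e4m3fn.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        # amax must be computed on the high-precision tensor; torch.abs and
        # torch.max are not supported on float8 dtypes.
        amax = x_hp.abs().max().to(torch.float32)
        scale = fp8_max / amax.clamp(min=1e-12)
        x_fp8 = (x_hp.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        # Return the dequantization scale (1 / scale), which torch._scaled_mm
        # multiplies back into the output.
        return x_fp8, scale.reciprocal()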
1 parent 9902c28 commit 3484c61

1 file changed: +30 −30 lines

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 30 additions & 30 deletions
@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
@@ -86,62 +99,49 @@ def args(m, n, k):
                 yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="TRITON",
            autotune_fallback_to_aten=False,
        ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -151,7 +151,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -164,7 +164,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
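For reference, a minimal standalone call that mirrors the non-rowwise path exercised by the updated benchmark (a sketch, assuming a CUDA device with FP8 support and a recent PyTorch; shapes are arbitrary):

    import torch

    m, n, k = 1024, 1024, 1024
    a_hp = torch.randn(m, k, device="cuda", dtype=torch.float16)
    # Second operand in column-major layout, as produced by get_input_iter()
    b_hp = torch.randn(k, n, device="cuda", dtype=torch.float16).T.contiguous().T

    # Tensor-wise unit scales (the non-rowwise branch)
    scale_a = torch.tensor(1.0, device="cuda")
    scale_b = torch.tensor(1.0, device="cuda")

    # Cast to the dtype the kernels expect
    a = a_hp.to(torch.float8_e4m3fn)
    b = b_hp.to(torch.float8_e4m3fn)

    c = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16)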
