
Commit ad791c4

jananisriram authored and facebook-github-bot committed
Move scaling logic to input generation (#338)
Summary: Move the scaling logic for FP8 benchmarks into `get_input_iter()`. This diff aligns our fp8_gemm benchmarking suite with real-world practice: input tensors are created in high-precision types (`bfloat16`, `float16`), scales are computed on the high-precision input tensors, and the input tensors are then cast to a lower precision (`float8_e4m3fn`). This also avoids performing operations that are unsupported on low-precision data types, such as `torch.max` and `torch.abs`.

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Differential Revision: D80571223

Pulled By: jananisriram
1 parent 34e755a commit ad791c4
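
The commit summary describes the standard FP8 workflow this change moves toward: compute scales on the high-precision tensors (where `torch.abs` and `torch.max` are supported), then cast down to `float8_e4m3fn`. Below is a minimal sketch of that per-tensor amax recipe, for illustration only; the helper name `to_float8_e4m3fn` is hypothetical, and the benchmark diff itself uses unit scales rather than amax-derived ones.

import torch

def to_float8_e4m3fn(x: torch.Tensor):
    # Derive the scale from the high-precision (bfloat16/float16) tensor
    # before the downcast; reductions like abs/max are supported here.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().max().float()
    scale = amax.clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale  # scale is what a scaled matmul later consumes

a = torch.randn(1024, 512, dtype=torch.bfloat16)
a_fp8, scale_a = to_float8_e4m3fn(a)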

File tree

1 file changed: +30 -30 lines changed


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 30 additions & 30 deletions
@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes") and self.external_shapes
@@ -90,62 +103,49 @@ def args(m, n, k):
                 yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
        )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
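
For context, a self-contained usage sketch of the new input contract: `get_input_iter()` now yields `(a, b, scale_a, scale_b)` and each benchmark forwards the scales to the kernel. The device, shapes, and standalone layout below are illustrative assumptions (FP8 matmul needs supporting hardware); the `torch._scaled_mm` call mirrors the one in the diff.

import torch

device = "cuda"  # assumption: a GPU with float8 support
m, k, n = 1024, 512, 2048

# High-precision inputs; b is laid out column-major, as in the diff
a = torch.randn(m, k, device=device).to(torch.float16)
b = torch.randn(k, n, device=device).to(torch.float16).T.contiguous().T

# Per-tensor (scalar) scales, matching the non-rowwise branch
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# Kernels expect dtype=float8_e4m3fn
a = a.to(torch.float8_e4m3fn)
b = b.to(torch.float8_e4m3fn)

out = torch._scaled_mm(
    a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)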
