
Commit 9902c28

jananisriram authored and facebook-github-bot committed
Add input loader module for fp8_gemm (#336)
Summary: Add input loader module for fp8_gemm kernels. This diff enables us to invoke custom shapes in fp8_gemm benchmarking by passing a `json` file into the TritonBench `--input-loader` arg.

Reviewed By: NikhilAPatel, xuzhao9

Differential Revision: D80565971
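The JSON schema for the new input loader isn't spelled out in this diff, but the shape-iteration code below expects each entry to be an `[M, N, K]` triple. A plausible shape file might look like the following sketch (the file name `shapes.json` is illustrative, not taken from the commit):

```json
[
    [1024, 1024, 1024],
    [4096, 4096, 8192],
    [16384, 8192, 1024]
]
```

Such a file would then be passed as, e.g., `--input-loader shapes.json` on the TritonBench command line; as the diff below shows, any malformed entry is skipped with a warning rather than aborting the run.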
1 parent f7dc7d1 · commit 9902c28

File tree

1 file changed: +9 -2 lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 9 additions & 2 deletions
@@ -65,7 +65,14 @@ def args(m, n, k):
             )
             return (a, b)
 
-        if self.extra_args.llama:
+        if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
+            for shape in self.external_shapes:
+                if len(shape) == 3:
+                    m, n, k = shape
+                    yield args(m, n, k)
+                else:
+                    logger.warning(f"Skipping invalid shape: {shape}, expected [M, N, K]")
+        elif self.extra_args.llama:
             for m, n, k, _bias in llama_shapes():
                 yield args(m, n, k)
         elif self.extra_args.m:
@@ -115,7 +122,7 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
             out_dtype = torch.bfloat16
         else:
             scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
+            scale_b = torch.tensor(1.0, device=b.device)
             out_dtype = torch.float16
         f = lambda a, b: torch._scaled_mm(
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
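
Two notes on the diff. The second hunk fixes a copy-paste bug: `scale_b` is now allocated on `b`'s device instead of `a`'s, which matters whenever the two operands live on different devices. And the loader module that populates `self.external_shapes` is not part of this file; a minimal sketch of what it might do, assuming the JSON file is a flat list of `[M, N, K]` triples (the helper name `load_external_shapes` is hypothetical):

```python
import json


def load_external_shapes(path: str) -> list:
    """Hypothetical helper: parse an --input-loader JSON file into a list of
    [M, N, K] shape triples. The real loader module is not shown in this diff."""
    with open(path) as f:
        shapes = json.load(f)
    if not isinstance(shapes, list):
        raise ValueError(f"expected a JSON list of shapes, got {type(shapes).__name__}")
    # Per-shape validation (len == 3) is deferred to the benchmark's input
    # iterator, which logs and skips malformed entries instead of failing.
    return shapes


# The operator would then stash the result where the new branch looks for it:
# self.external_shapes = load_external_shapes(args.input_loader)
```

The `hasattr` guard in the first hunk keeps the change backward compatible: operators that never set `external_shapes` fall through to the existing `--llama` / `--m` paths unchanged.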
