Commit fa77af5

mxfp8 emulated grouped gemm
add emulated mxfp8 grouped gemm

stack-info: PR: #2626, branch: danielvegamyhre/stack/22
1 parent 0e00df3 commit fa77af5
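
Background: MX formats store low-precision elements (here fp8 e4m3) together with one shared power-of-two scale per block of 32 consecutive values, and "emulated" here means dequantizing back to bfloat16 and running a plain bf16 grouped GEMM rather than a native mxfp8 kernel. The sketch below is a minimal, hypothetical illustration of the block-scaled quantize/dequantize round trip; mx_quant_dequant is not part of this PR, and the real recipe lives in torchao's to_mx.

import torch

def mx_quant_dequant(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Illustration only: round-trip x through fp8 e4m3 with one
    # power-of-two scale per block, loosely mimicking the MX format.
    assert x.shape[-1] % block_size == 0
    orig_shape = x.shape
    xb = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
    # Per-block max magnitude (clamped so log2 below stays finite).
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # Smallest power-of-two scale such that every block element fits in e4m3.
    scale = torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn)   # quantize
    x_dq = x_fp8.to(x.dtype) * scale.to(x.dtype)   # dequantize
    return x_dq.reshape(orig_shape)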

4 files changed: +129 −34 lines

benchmarks/float8/bench_grouped_mm.py

Lines changed: 1 addition & 32 deletions
@@ -12,6 +12,7 @@
 from utils import do_benchmarks, get_name_to_moe_shapes_iter
 
 from torchao.testing.training.roofline_utils import get_specs
+from torchao.prototype.moe_training.utils import generate_jagged_offs
 
 
 @torch.inference_mode()
@@ -146,38 +147,6 @@ def do_scaled_grouped_mm(A, B):
     data_df.to_csv(out_filename)
 
 
-def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
-    """
-    Generates a tensor of length E, containing random values divisible by 16,
-    from 0 to M, in sorted order, and where the final value in the tensor is always M.
-    Args:
-        E (int): The length of the tensor.
-        M (int): The maximum value in the tensor.
-    Returns:
-        torch.Tensor: A tensor of length E with the specified properties.
-    """
-    # Ensure M is divisible by 16
-    if M % 16 != 0:
-        raise ValueError("M must be divisible by 16")
-
-    # Generate a list of possible values
-    possible_values = [i for i in range(0, M + 1, 16)]
-
-    # If E is larger than the number of possible values, raise an error
-    if E > len(possible_values):
-        raise ValueError("E cannot be larger than the number of possible values")
-
-    # Randomly select E - 1 values from the possible values (excluding M)
-    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
-
-    # Append M to the selected values
-    selected_values = torch.cat((selected_values, torch.tensor([M])))
-
-    # Sort the selected values
-    selected_values, _ = torch.sort(selected_values)
-
-    return selected_values.to(dtype).to(device)
-
 
 def main() -> None:
     fire.Fire(run)
test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 36 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytest
 import torch
 
+from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 # We need to skip before doing any imports which would use triton, since
@@ -25,10 +26,12 @@
 )
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_training_tensor import LinearMMConfig
-from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
+    emulated_mxfp8_scaled_grouped_mm,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
 
@@ -212,3 +215,35 @@ def compute_reference_forward(
     # Concatenate the outputs and verify the full result is correct.
     output_ref = torch.cat(outputs, dim=0)
     return output_ref
+
+
+@pytest.mark.parametrize("M", (1024, 4096))
+@pytest.mark.parametrize("K", (1024, 4096))
+@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("num_experts", (1, 8, 16))
+def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M)
+    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+
+    # Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
+    block_size = 32
+    x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+    # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
+    w_scale, w_mx = to_mx(
+        w_t.transpose(-2, -1).contiguous(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+    )
+    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+
+    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    out = emulated_mxfp8_scaled_grouped_mm(
+        x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
+    )
+
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
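
Note on the assertion: compute_error returns a signal-to-quantization-noise ratio (SQNR) in dB, so min_sqnr = 27.0 requires the error norm to be roughly 4.5% of the reference norm or less, consistent with fp8 e4m3's ~3 significand bits. A hedged sketch of the usual SQNR definition (for orientation only; the authoritative version is torchao.float8.float8_utils.compute_error):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # 20*log10 of the ratio of signal norm to error norm;
    # 27 dB <=> error norm ~ ref norm / 10**(27/20) ~ 4.5% of ref norm.
    signal = torch.norm(ref.float())
    noise = torch.norm(ref.float() - out.float())
    return 20 * torch.log10(signal / noise)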

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 56 additions & 1 deletion
@@ -217,7 +217,7 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
 
-        # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()
 
@@ -266,3 +266,58 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
+
+
+def emulated_mxfp8_scaled_grouped_mm(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    # Dequantize input
+    # A_mx shape: (M, K)
+    # A_scale shape: (M, K//block_size)
+    A_orig_shape = A_mx.shape
+
+    # Reshape to be able to do per-scaling group multiplication
+    # A_mx shape: (M, K//block_size, block_size)
+    # A_scale shape: (M, K//block_size, 1)
+    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_scale = A_scale.unsqueeze(-1)
+
+    # Rescale and cast to bfloat16
+    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+
+    # Reshape back to original shape
+    # A shape: (M, K)
+    A = A.reshape(A_orig_shape)
+
+    # Dequantize weights
+    # B_t_mx shape: (E, K, N)
+    # B_t_scale shape: (E, K//block_size, N)
+    E, K, N = B_t_mx.shape
+
+    # Transpose to get block_size on rightmost dim
+    # B_mx shape: (E, N, K)
+    # B_scale shape: (E, N, K//block_size)
+    B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)
+
+    # Reshape to be able to do per-scaling group multiplication
+    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_scale shape: (E, N, K//block_size, 1)
+    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_scale = B_scale.unsqueeze(-1)
+
+    # Rescale and cast to bfloat16
+    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+
+    # Reshape and transpose back to original shape
+    # B_t shape: (E, K, N)
+    B_t = B.reshape(E, N, K).transpose(-2, -1)
+
+    # Perform bf16 grouped GEMM.
+    out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
+    return out
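
The reshape-then-broadcast pattern above is the core trick: splitting K into (K // block_size, block_size) lets a (..., K // block_size, 1) scale tensor broadcast across each 32-element block in a single multiply, with no explicit loop over scaling groups. A small self-contained sketch of the equivalence (toy shapes, unrelated to this PR's tensors):

import torch

M, K, bs = 4, 64, 32
data = torch.randn(M, K // bs, bs)        # payload, kept fp32 for simplicity
scales = torch.rand(M, K // bs) + 0.5     # one scale per block

# Broadcast multiply: (M, K//bs, bs) * (M, K//bs, 1) -> (M, K//bs, bs)
dequant = data * scales.unsqueeze(-1)

# Same thing written as an explicit loop over scaling groups
manual = torch.stack(
    [data[:, g] * scales[:, g, None] for g in range(K // bs)], dim=1
)
assert torch.equal(dequant, manual)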

torchao/prototype/moe_training/utils.py

Lines changed: 36 additions & 0 deletions
@@ -1,3 +1,4 @@
+import random
 from typing import Tuple
 
 import torch
@@ -154,3 +155,38 @@ def _is_column_major(x: torch.Tensor) -> bool:
     """
     assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
     return x.stride(-2) == 1 and x.stride(-1) > 1
+
+
+def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
+    """
+    Utility function for tests and benchmarks.
+
+    Generates a tensor of length E, containing random values divisible by 16,
+    from 0 to M, in sorted order, and where the final value in the tensor is always M.
+    Args:
+        E (int): The length of the tensor.
+        M (int): The maximum value in the tensor.
+    Returns:
+        torch.Tensor: A tensor of length E with the specified properties.
+    """
+    # Ensure M is divisible by 16
+    if M % 16 != 0:
+        raise ValueError("M must be divisible by 16")
+
+    # Generate a list of possible values
+    possible_values = [i for i in range(0, M + 1, 16)]
+
+    # If E is larger than the number of possible values, raise an error
+    if E > len(possible_values):
+        raise ValueError("E cannot be larger than the number of possible values")
+
+    # Randomly select E - 1 values from the possible values (excluding M)
+    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
+
+    # Append M to the selected values
+    selected_values = torch.cat((selected_values, torch.tensor([M])))
+
+    # Sort the selected values
+    selected_values, _ = torch.sort(selected_values)
+
+    return selected_values.to(dtype).to(device)
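
For orientation, the offsets this helper returns are the cumulative end-of-group boundaries torch._grouped_mm expects: expert e consumes rows offs[e-1]:offs[e] (with the final boundary equal to M), and the divisible-by-16 constraint presumably satisfies the grouped GEMM kernels' alignment requirements. A quick usage sketch (the example values in the comment are illustrative; the output is random):

import torch
from torchao.prototype.moe_training.utils import generate_jagged_offs

offs = generate_jagged_offs(4, 256, device="cpu")
# e.g. tensor([ 48, 112, 208, 256], dtype=torch.int32): expert 0 owns rows
# 0:48, expert 1 rows 48:112, and the last offset is always M=256.
group_sizes = torch.diff(offs, prepend=torch.tensor([0], dtype=offs.dtype))
assert group_sizes.sum() == 256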
