|
6 | 6 |
|
7 | 7 | import pytest
|
8 | 8 | import torch
|
| 9 | +from torch.nn import functional as F |
9 | 10 |
|
10 | 11 | pytest.importorskip("triton", reason="Triton required to run this test")
|
11 | 12 |
|
@@ -303,19 +304,47 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
|
303 | 304 |
|
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
    """Compare _MXFP8GroupedMM against torch._grouped_mm in bf16.

    Runs a forward pass through both implementations on the same inputs,
    then drives identical MSE losses backward and checks that the output,
    input-gradient, and weight-gradient SQNRs each clear their threshold.
    """
    from torchao.prototype.moe_training.scaled_grouped_mm import _MXFP8GroupedMM

    block_size = 32

    # Inputs for the mxfp8 path; requires_grad so backward() populates .grad.
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    w_t = torch.randn(
        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    # Group offsets aligned to the mxfp8 scaling block size.
    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)

    # Independent leaf copies for the bf16 reference path, so each path
    # accumulates its own gradients.
    x_ref = x.clone().detach().requires_grad_(True)
    w_t_ref = w_t.clone().detach().requires_grad_(True)
    offs_ref = offs.clone()

    # Forward
    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
    min_sqnr = 27.0
    out_sqnr = compute_error(ref_out, out)
    assert out_sqnr >= min_sqnr, (
        f"Output sqnr {out_sqnr} is too low, must be >= {min_sqnr}"
    )

    # Backward: the same all-ones target and MSE loss on each path.
    labels = torch.ones_like(ref_out)
    F.mse_loss(ref_out, labels).backward()
    F.mse_loss(out, labels).backward()

    # Check input grads
    min_input_grad_sqnr = 26.0
    input_grad_sqnr = compute_error(x_ref.grad, x.grad)
    assert input_grad_sqnr >= min_input_grad_sqnr, (
        f"Input grad sqnr {input_grad_sqnr} is too low, must be >= {min_input_grad_sqnr}"
    )

    # Check weight grads
    min_weight_grad_sqnr = 24.0
    weight_grad_sqnr = compute_error(w_t_ref.grad, w_t.grad)
    assert weight_grad_sqnr >= min_weight_grad_sqnr, (
        f"Weight grad sqnr {weight_grad_sqnr} is too low, must be >= {min_weight_grad_sqnr}"
    )
0 commit comments