import pytest

import torch
from torch.nn import functional as F

# Triton kernels back the MXFP8 grouped-GEMM under test; skip the whole
# module (not just individual tests) when Triton is unavailable.
pytest.importorskip("triton", reason="Triton required to run this test")
|
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
    """Check MXFP8 grouped GEMM numerics against the bf16 reference.

    Runs both the forward pass and the backward pass of
    ``_MXFP8GroupedMM`` and compares output, input-grad, and weight-grad
    SQNR against ``torch._grouped_mm`` on identical bf16 data. The SQNR
    thresholds are empirical floors; higher is better.
    """
    from torchao.prototype.moe_training.scaled_grouped_mm import (
        _MXFP8GroupedMM,
    )

    block_size = 32
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    w_t = torch.randn(
        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    # Group boundaries must align to the MXFP8 scaling block size so each
    # expert's rows scale cleanly — hence multiple_of=block_size.
    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
    # Detach the reference copies so each path accumulates its own grads
    # on independent leaf tensors.
    x_ref, w_t_ref, offs_ref = (
        x.clone().detach().requires_grad_(True),
        w_t.clone().detach().requires_grad_(True),
        offs.clone(),
    )

    # Forward
    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
    sqnr = compute_error(ref_out, out)
    min_sqnr = 27.0
    assert sqnr >= min_sqnr, f"Output sqnr {sqnr} is too low, must be >= {min_sqnr}"

    # Backward: drive both paths with the same MSE loss against a
    # constant target so the incoming grad_output is comparable.
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)
    ref_loss.backward()
    out_loss.backward()

    # Check input grads
    min_input_grad_sqnr = 26.0
    sqnr = compute_error(x_ref.grad, x.grad)
    assert sqnr >= min_input_grad_sqnr, (
        f"Input grad sqnr {sqnr} is too low, must be >= {min_input_grad_sqnr}"
    )

    # Check weight grads (slightly looser floor than input grads).
    min_weight_grad_sqnr = 24.0
    sqnr = compute_error(w_t_ref.grad, w_t.grad)
    assert sqnr >= min_weight_grad_sqnr, (
        f"Weight grad sqnr {sqnr} is too low, must be >= {min_weight_grad_sqnr}"
    )