Skip to content

Commit 2bb90ba

Browse files
authored
[TRTLLM-6960][fix] enable scaled_mm tests (#6936)
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 06911c0 commit 2bb90ba

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

tests/unittest/_torch/thop/test_scaled_mm.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@
3838
[torch.float16, torch.float32, torch.bfloat16],
3939
)
4040
def test_fp8_scaled_mm(output_dtype, m, k_n):
41-
if getSMVersion() == 90:
42-
pytest.skip(
43-
"Skip test for sm90 because it's too flaky. https://nvbugspro.nvidia.com/bug/5441734"
44-
)
45-
4641
k, n = k_n
4742
torch.random.manual_seed(0)
4843
shape_x = (m, k)
@@ -76,7 +71,7 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
7671
os.environ["CUBLASLT_WORKSPACE_SIZE"] = old_env
7772
np.testing.assert_allclose(ref.float().cpu(),
7873
output.float().cpu(),
79-
atol=1,
74+
atol=0.01,
8075
rtol=0.01)
8176

8277
if getSMVersion() == 90:

0 commit comments

Comments
 (0)