Commit cf53713

Migrate to unittest for files in test/dtypes
1 parent 2f8fd69 commit cf53713

File tree: 2 files changed (+30, -33 lines)

test/dtypes/test_affine_quantized_float.py

Lines changed: 29 additions & 31 deletions
@@ -3,14 +3,14 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest

 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")

 import copy
 import io
@@ -20,10 +20,10 @@
 from functools import partial
 from typing import Tuple

-import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import parametrize, run_tests

 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
@@ -74,12 +74,12 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("dtype", [torch.bfloat16, torch.float32])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("compile", [True, False])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     # Inputs are (M,..), K, N
-    @common_utils.parametrize(
+    @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
@@ -99,7 +99,7 @@ def test_fp8_linear_variants(
         )

         error_context = (
-            pytest.raises(AssertionError, match=error_message)
+            self.assertRaisesRegex(AssertionError, error_message)
             if error_message
             else nullcontext()
         )
@@ -150,16 +150,16 @@ def test_fp8_linear_variants(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_invalid_granularity(self):
-        with pytest.raises(ValueError, match="Invalid granularity specification"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_mismatched_granularity(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            "Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

@@ -170,7 +170,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass

-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
             float8_dynamic_activation_float8_weight(
                 granularity=(UnsupportedGranularity(), UnsupportedGranularity())
             )
@@ -180,9 +180,9 @@ class UnsupportedGranularity:
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_per_row_with_float32(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             AssertionError,
-            match="PerRow quantization only works for bfloat16 precision",
+            "PerRow quantization only works for bfloat16 precision",
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
@@ -193,7 +193,7 @@ def test_per_row_with_float32(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
@@ -300,13 +300,11 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize(
-        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
-    )
-    @common_utils.parametrize(
+    @parametrize("in_features,out_features", [(512, 1024), (256, 768), (1024, 512)])
+    @parametrize(
         "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
     ) # fmt: skip
-    @common_utils.parametrize("bias", [True, False])
+    @parametrize("bias", [True, False])
     def test_mm_float8dq_per_row(
         self, in_features, out_features, leading_shape, bias: bool
     ):
@@ -354,8 +352,8 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
     def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
         block_size = ()
         device = "cuda"
@@ -397,9 +395,9 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""

@@ -462,7 +460,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_basic(self, granularity):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
@@ -595,7 +593,7 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
@@ -718,8 +716,8 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("hp_dtype", [torch.float32, torch.bfloat16])
     def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
         quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
         dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
@@ -762,4 +760,4 @@ def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

 if __name__ == "__main__":
-    pytest.main([__file__])
+    run_tests()
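
For reference, the idioms this file moves to — `parametrize` from `torch.testing._internal.common_utils` in place of `pytest.mark.parametrize`, `assertRaisesRegex` in place of `pytest.raises(..., match=...)`, and `run_tests()` as the entry point — can be summarized in a minimal, self-contained sketch. This is not code from the repository; the class, test names, and error message are made up for illustration:

# Illustrative sketch of the unittest idioms adopted above; the test class,
# test names, and error message are hypothetical.
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests


class ToyMigratedTest(TestCase):
    # @parametrize replaces pytest.mark.parametrize; the concrete test
    # variants are generated by instantiate_parametrized_tests() below.
    @parametrize("value", [1, 2, 3])
    def test_positive(self, value):
        self.assertGreater(value, 0)

    def test_error_message(self):
        # assertRaisesRegex replaces pytest.raises(..., match=...); the pattern
        # is a positional regex argument rather than a match= keyword.
        with self.assertRaisesRegex(ValueError, "invalid"):
            raise ValueError("invalid granularity")


common_utils.instantiate_parametrized_tests(ToyMigratedTest)

if __name__ == "__main__":
    # run_tests() replaces pytest.main([__file__]).
    run_tests()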

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import unittest

-import pytest
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
@@ -34,7 +33,7 @@
     has_gemlite = False

 if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+    raise unittest.SkipTest("Skipping the test in ROCm", allow_module_level=True)


 class TestAffineQuantizedTensorParallel(DTensorTestBase):
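
The second file keeps only the module-level skip, now raised through unittest. A minimal sketch of that pattern follows (illustrative only, with a placeholder test class); note that `allow_module_level` is a pytest-specific keyword with no `unittest.SkipTest` equivalent, so the sketch passes just the reason string:

# Minimal sketch of a module-level skip under unittest (not repository code).
# When the module is collected via unittest discovery, a SkipTest raised at
# import time marks the whole module as skipped rather than as an error.
import unittest

import torch

if torch.version.hip is not None:
    # unittest.SkipTest takes only a reason message.
    raise unittest.SkipTest("Skipping the test in ROCm")


class PlaceholderTest(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()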
