
Commit 3641c98

Migrate to unittest for files in test/dtypes
1 parent 376d6d2
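The change is mechanical and follows the same pattern in every file: module-level skips become a raised unittest.SkipTest, pytest.raises(..., match=...) becomes self.assertRaisesRegex(...), pytest.mark.parametrize and common_utils.parametrize become the bare parametrize decorator from torch.testing._internal.common_utils, and pytest.main([__file__]) becomes run_tests(). A minimal sketch of the pattern, for orientation only; the test class, gate, and assertion below are illustrative and not taken from any one file in this commit:

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

# Module-level skip: replaces pytest.skip("...", allow_module_level=True).
if not torch.cuda.is_available():  # hypothetical gate, for illustration only
    raise unittest.SkipTest("CUDA not available")


class ExampleMigratedTest(TestCase):  # hypothetical test class
    # @pytest.mark.parametrize(...) becomes @parametrize(...) plus an
    # instantiate_parametrized_tests(...) call on the class after its definition.
    @parametrize("dtype", [torch.float32, torch.bfloat16])
    def test_negative_dim_rejected(self, dtype):
        # pytest.raises(RuntimeError, match="...") becomes assertRaisesRegex.
        with self.assertRaisesRegex(RuntimeError, "negative dimension"):
            torch.empty(-1, dtype=dtype)


instantiate_parametrized_tests(ExampleMigratedTest)

if __name__ == "__main__":
    run_tests()  # replaces pytest.main([__file__])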

File tree: 6 files changed (+235, -230 lines)


test/dtypes/test_affine_quantized_float.py

Lines changed: 27 additions & 29 deletions
@@ -3,14 +3,14 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest

 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")

 import copy
 import io
@@ -20,10 +20,10 @@
 from functools import partial
 from typing import Tuple

-import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import parametrize, run_tests

 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
@@ -74,12 +74,12 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("dtype", [torch.bfloat16, torch.float32])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("compile", [True, False])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     # Inputs are (M,..), K, N
-    @common_utils.parametrize(
+    @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
@@ -99,7 +99,7 @@ def test_fp8_linear_variants(
         )

         error_context = (
-            pytest.raises(AssertionError, match=error_message)
+            self.assertRaisesRegex(AssertionError, error_message)
             if error_message
             else nullcontext()
         )
@@ -150,16 +150,16 @@ def test_fp8_linear_variants(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_invalid_granularity(self):
-        with pytest.raises(ValueError, match="Invalid granularity specification"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_mismatched_granularity(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            "Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

@@ -170,7 +170,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass

-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
             float8_dynamic_activation_float8_weight(
                 granularity=(UnsupportedGranularity(), UnsupportedGranularity())
             )
@@ -180,9 +180,9 @@ class UnsupportedGranularity:
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_per_row_with_float32(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             AssertionError,
-            match="PerRow quantization only works for bfloat16 precision",
+            "PerRow quantization only works for bfloat16 precision",
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
@@ -193,7 +193,7 @@ def test_per_row_with_float32(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
@@ -300,13 +300,11 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize(
-        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
-    )
-    @common_utils.parametrize(
+    @parametrize("in_features,out_features", [(512, 1024), (256, 768), (1024, 512)])
+    @parametrize(
         "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
     ) # fmt: skip
-    @common_utils.parametrize("bias", [True, False])
+    @parametrize("bias", [True, False])
     def test_mm_float8dq_per_row(
         self, in_features, out_features, leading_shape, bias: bool
     ):
@@ -354,9 +352,9 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""

@@ -419,7 +417,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_basic(self, granularity):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
@@ -552,7 +550,7 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
@@ -675,8 +673,8 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1) # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("hp_dtype", [torch.float32, torch.bfloat16])
     def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
         quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
         dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
@@ -719,4 +717,4 @@ def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

 if __name__ == "__main__":
-    pytest.main([__file__])
+    run_tests()
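One hunk above relies on the fact that assertRaisesRegex, like pytest.raises, returns a context manager, so the conditional error_context pattern in test_fp8_linear_variants carries over unchanged. A small self-contained sketch of that pattern; the helper method, values, and messages here are illustrative, not copied from the test file:

import unittest
from contextlib import nullcontext


class ErrorContextExample(unittest.TestCase):
    def _run(self, value, error_message=None):
        # Pick the context at runtime: expect an error only when a message is given.
        error_context = (
            self.assertRaisesRegex(ValueError, error_message)
            if error_message
            else nullcontext()
        )
        with error_context:
            if value < 0:
                raise ValueError("value must be non-negative")

    def test_valid_value(self):
        self._run(1)

    def test_invalid_value(self):
        self._run(-1, error_message="must be non-negative")


if __name__ == "__main__":
    unittest.main()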

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import unittest

-import pytest
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
@@ -34,7 +33,7 @@
 has_gemlite = False

 if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+    raise unittest.SkipTest("Skipping the test in ROCm")


 class TestAffineQuantizedTensorParallel(DTensorTestBase):

test/dtypes/test_bitpacking.py

Lines changed: 71 additions & 58 deletions
@@ -3,8 +3,15 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+from unittest import skipIf
+
 import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 from torch.utils._triton import has_triton

 from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
@@ -13,68 +20,74 @@
 dimensions = (0, -1, 1)


-@pytest.fixture(autouse=True)
-def run_before_and_after_tests():
-    yield
-    torch._dynamo.reset() # reset cache between tests
-
-
-@pytest.mark.parametrize("bit_width", bit_widths)
-@pytest.mark.parametrize("dim", dimensions)
-def test_CPU(bit_width, dim):
-    test_tensor = torch.randint(
-        0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu"
-    )
-    packed = pack_cpu(test_tensor, bit_width, dim=dim)
-    unpacked = unpack_cpu(packed, bit_width, dim=dim)
-    assert unpacked.allclose(test_tensor)
+class TestBitpacking(TestCase):
+    def tearDown(self):
+        torch._dynamo.reset() # reset cache between tests

+    @parametrize("bit_width", bit_widths)
+    @parametrize("dim", dimensions)
+    def test_CPU(self, bit_width, dim):
+        test_tensor = torch.randint(
+            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu"
+        )
+        packed = pack_cpu(test_tensor, bit_width, dim=dim)
+        unpacked = unpack_cpu(packed, bit_width, dim=dim)
+        assert unpacked.allclose(test_tensor)

-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("bit_width", bit_widths)
-@pytest.mark.parametrize("dim", dimensions)
-def test_GPU(bit_width, dim):
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
-    packed = pack(test_tensor, bit_width, dim=dim)
-    unpacked = unpack(packed, bit_width, dim=dim)
-    assert unpacked.allclose(test_tensor)
+    @skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @parametrize("bit_width", bit_widths)
+    @parametrize("dim", dimensions)
+    def test_GPU(self, bit_width, dim):
+        test_tensor = torch.randint(
+            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
+        ).cuda()
+        packed = pack(test_tensor, bit_width, dim=dim)
+        unpacked = unpack(packed, bit_width, dim=dim)
+        assert unpacked.allclose(test_tensor)

+    @skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @skipIf(not has_triton(), reason="unsupported without triton")
+    @parametrize("bit_width", bit_widths)
+    @parametrize("dim", dimensions)
+    def test_compile(self, bit_width, dim):
+        torch._dynamo.config.specialize_int = True
+        torch.compile(pack, fullgraph=True)
+        torch.compile(unpack, fullgraph=True)
+        test_tensor = torch.randint(
+            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
+        ).cuda()
+        packed = pack(test_tensor, bit_width, dim=dim)
+        unpacked = unpack(packed, bit_width, dim=dim)
+        assert unpacked.allclose(test_tensor)

-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("bit_width", bit_widths)
-@pytest.mark.parametrize("dim", dimensions)
-def test_compile(bit_width, dim):
-    torch._dynamo.config.specialize_int = True
-    torch.compile(pack, fullgraph=True)
-    torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
-    packed = pack(test_tensor, bit_width, dim=dim)
-    unpacked = unpack(packed, bit_width, dim=dim)
-    assert unpacked.allclose(test_tensor)
+    # these test cases are for the example pack walk through in the bitpacking.py file
+    @skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_pack_example(self):
+        test_tensor = torch.tensor(
+            [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
+        ).cuda()
+        shard_4, shard_2 = pack(test_tensor, 6)
+        print(shard_4, shard_2)
+        assert (
+            torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
+        )
+        assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+        unpacked = unpack([shard_4, shard_2], 6)
+        assert unpacked.allclose(test_tensor)

+    def test_pack_example_CPU(self):
+        test_tensor = torch.tensor(
+            [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
+        )
+        shard_4, shard_2 = pack(test_tensor, 6)
+        print(shard_4, shard_2)
+        assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
+        assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
+        unpacked = unpack([shard_4, shard_2], 6)
+        assert unpacked.allclose(test_tensor)

-# these test cases are for the example pack walk through in the bitpacking.py file
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_pack_example():
-    test_tensor = torch.tensor(
-        [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    ).cuda()
-    shard_4, shard_2 = pack(test_tensor, 6)
-    print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
-    unpacked = unpack([shard_4, shard_2], 6)
-    assert unpacked.allclose(test_tensor)

+instantiate_parametrized_tests(TestBitpacking)

-def test_pack_example_CPU():
-    test_tensor = torch.tensor(
-        [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    )
-    shard_4, shard_2 = pack(test_tensor, 6)
-    print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
-    unpacked = unpack([shard_4, shard_2], 6)
-    assert unpacked.allclose(test_tensor)
+if __name__ == "__main__":
+    run_tests()
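The autouse pytest fixture that reset the dynamo cache after each test maps onto unittest's per-test tearDown hook, which is what the new TestBitpacking class uses. A minimal sketch of that equivalence; the class and test below are an illustrative stand-in, not code from this commit:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class DynamoResetExample(TestCase):  # illustrative stand-in for TestBitpacking
    def tearDown(self):
        # Runs after every test method, like the yield-style autouse fixture:
        # clears compiled-code caches so tests stay independent.
        torch._dynamo.reset()

    def test_compile_roundtrip(self):
        compiled = torch.compile(lambda x: x + 1)
        self.assertTrue(torch.equal(compiled(torch.ones(2)), torch.full((2,), 2.0)))


if __name__ == "__main__":
    run_tests()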

test/dtypes/test_nf4.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,6 @@
 from collections import OrderedDict
 from typing import Tuple, Union

-import pytest
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -628,9 +627,9 @@ class TestQLoRA(FSDPTest):
     def world_size(self) -> int:
         return 2

-    @pytest.mark.skipif(
+    @unittest.skipIf(
         version.parse(torch.__version__).base_version < "2.4.0",
-        reason="torch >= 2.4 required",
+        "torch >= 2.4 required",
     )
     @skip_if_lt_x_gpu(2)
     def test_qlora_fsdp2(self):
