@@ -3,14 +3,14 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
 
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 import copy
 import io
@@ -20,10 +20,10 @@
 from functools import partial
 from typing import Tuple
 
-import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import parametrize, run_tests
 
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
@@ -74,12 +74,12 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("dtype", [torch.bfloat16, torch.float32])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("compile", [True, False])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     # Inputs are (M,..), K, N
-    @common_utils.parametrize(
+    @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
@@ -99,7 +99,7 @@ def test_fp8_linear_variants(
         )
 
         error_context = (
-            pytest.raises(AssertionError, match=error_message)
+            self.assertRaisesRegex(AssertionError, error_message)
             if error_message
             else nullcontext()
         )
@@ -150,16 +150,16 @@ def test_fp8_linear_variants(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_invalid_granularity(self):
-        with pytest.raises(ValueError, match="Invalid granularity specification"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_mismatched_granularity(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            "Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
@@ -170,7 +170,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
 
-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
             float8_dynamic_activation_float8_weight(
                 granularity=(UnsupportedGranularity(), UnsupportedGranularity())
             )
@@ -180,9 +180,9 @@ class UnsupportedGranularity:
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_per_row_with_float32(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             AssertionError,
-            match="PerRow quantization only works for bfloat16 precision",
+            "PerRow quantization only works for bfloat16 precision",
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
@@ -193,7 +193,7 @@ def test_per_row_with_float32(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
@@ -300,13 +300,11 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize(
-        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
-    )
-    @common_utils.parametrize(
+    @parametrize("in_features,out_features", [(512, 1024), (256, 768), (1024, 512)])
+    @parametrize(
         "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
     )  # fmt: skip
-    @common_utils.parametrize("bias", [True, False])
+    @parametrize("bias", [True, False])
     def test_mm_float8dq_per_row(
         self, in_features, out_features, leading_shape, bias: bool
     ):
@@ -354,8 +352,8 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
     def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
         block_size = ()
         device = "cuda"
@@ -397,9 +395,9 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 
@@ -462,7 +460,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_basic(self, granularity):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
@@ -595,7 +593,7 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
@@ -718,8 +716,8 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("hp_dtype", [torch.float32, torch.bfloat16])
     def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
         quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
         dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
@@ -762,4 +760,4 @@ def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    run_tests()
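
The idioms this commit swaps in are summarized below as a minimal, self-contained sketch. The `ExampleTest` class, `SOME_FEATURE_AVAILABLE` flag, and test bodies are illustrative only (not from the torchao sources); only the imports and the pytest-to-unittest replacements mirror the diff above.

```python
# Minimal sketch of the migration pattern (hypothetical ExampleTest class).
import unittest

from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import parametrize, run_tests

# Module-level skip: pytest.skip(..., allow_module_level=True) becomes a raised
# unittest.SkipTest, which both unittest discovery and pytest treat as
# "skip every test in this module".
SOME_FEATURE_AVAILABLE = True  # placeholder for a real capability check
if not SOME_FEATURE_AVAILABLE:
    raise unittest.SkipTest("Unsupported PyTorch version")


class ExampleTest(common_utils.TestCase):
    # parametrize imported directly, instead of common_utils.parametrize
    @parametrize("value", [1, 2, 3])
    def test_positive(self, value):
        self.assertGreater(value, 0)

    def test_error_message(self):
        # pytest.raises(ValueError, match="...") becomes assertRaisesRegex;
        # the second argument is a regex searched in the exception message,
        # matching the semantics of pytest's match=.
        with self.assertRaisesRegex(ValueError, "invalid literal"):
            int("not a number")


# Expands the @parametrize decorators into individual test cases.
common_utils.instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    # run_tests() replaces pytest.main([__file__]) as the entry point.
    run_tests()
```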