From 7f37b7dc090907c6741686ce20bed2713588a281 Mon Sep 17 00:00:00 2001
From: zeshengzong
Date: Fri, 25 Jul 2025 09:09:50 +0000
Subject: [PATCH] Migrate to unittest for files in test/dtypes

---
 test/dtypes/test_affine_quantized_float.py    | 60 +++++++++----------
 .../test_affine_quantized_tensor_parallel.py  |  3 +-
 2 files changed, 30 insertions(+), 33 deletions(-)

diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 56010d7d1b..ee8db398d3 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -3,14 +3,14 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
 
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 import copy
 import io
@@ -20,11 +20,11 @@
 from functools import partial
 from typing import Tuple
 
-import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.profiler import ProfilerActivity, profile
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import parametrize, run_tests
 
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
@@ -75,12 +75,12 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("dtype", [torch.bfloat16, torch.float32])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("compile", [True, False])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     # Inputs are (M,..), K, N
-    @common_utils.parametrize(
+    @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
@@ -100,7 +100,7 @@ def test_fp8_linear_variants(
         )
 
         error_context = (
-            pytest.raises(AssertionError, match=error_message)
+            self.assertRaisesRegex(AssertionError, error_message)
             if error_message
             else nullcontext()
         )
@@ -151,16 +151,16 @@ def test_fp8_linear_variants(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_invalid_granularity(self):
-        with pytest.raises(ValueError, match="Invalid granularity specification"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_mismatched_granularity(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            "Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
@@ -171,7 +171,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
 
-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
             float8_dynamic_activation_float8_weight(
                 granularity=(UnsupportedGranularity(), UnsupportedGranularity())
             )
@@ -181,9 +181,9 @@ class UnsupportedGranularity:
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
     def test_per_row_with_float32(self):
-        with pytest.raises(
+        with self.assertRaisesRegex(
             AssertionError,
-            match="PerRow quantization only works for bfloat16 precision",
+            "PerRow quantization only works for bfloat16 precision",
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
@@ -194,7 +194,7 @@ def test_per_row_with_float32(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
+    @parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
@@ -301,13 +301,11 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize(
-        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
-    )
-    @common_utils.parametrize(
+    @parametrize("in_features,out_features", [(512, 1024), (256, 768), (1024, 512)])
+    @parametrize(
         "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
     )  # fmt: skip
-    @common_utils.parametrize("bias", [True, False])
+    @parametrize("bias", [True, False])
     def test_mm_float8dq_per_row(
         self, in_features, out_features, leading_shape, bias: bool
     ):
@@ -355,8 +353,8 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
     def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
         block_size = ()
         device = "cuda"
@@ -398,9 +396,9 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 
@@ -463,7 +461,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_basic(self, granularity):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
@@ -596,7 +594,7 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @parametrize("granularity", [PerTensor(), PerRow()])
@parametrize("granularity", [PerTensor(), PerRow()]) @unittest.skipIf( is_sm_version(8, 9), "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15", @@ -724,8 +722,8 @@ def test_preprocess_scale_3d_reshape(self): @unittest.skipIf( not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0" ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - @common_utils.parametrize( + @parametrize("granularity", [PerTensor(), PerRow()]) + @parametrize( "torch_compile_mode", [ "default", @@ -792,4 +790,4 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode): common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) if __name__ == "__main__": - pytest.main([__file__]) + run_tests() diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index c2eff77b07..806c79a382 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import unittest -import pytest import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils @@ -37,7 +36,7 @@ has_gemlite = False if torch.version.hip is not None: - pytest.skip("Skipping the test in ROCm", allow_module_level=True) + raise unittest.SkipTest("Skipping the test in ROCm", allow_module_level=True) class TestAffineQuantizedTensorParallel(DTensorTestBase):