Migrate to unittest for files in test/dtypes #2605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
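This PR replaces pytest idioms with their unittest / torch.testing counterparts in test/dtypes. As a rough illustration of the main assertion pattern touched below — `pytest.raises(..., match=...)` becoming `self.assertRaisesRegex(...)` — here is a minimal, self-contained sketch; the class, test name, and error message are made up for illustration and are not part of the diff.

```python
import unittest


class ExampleRaisesTest(unittest.TestCase):
    def test_invalid_granularity_message(self):
        # pytest style (before):
        #     with pytest.raises(ValueError, match="Invalid granularity"):
        #         ...
        # unittest style (after): assertRaisesRegex also regex-searches the
        # exception message, so the former `match` pattern can be reused verbatim.
        with self.assertRaisesRegex(ValueError, "Invalid granularity"):
            raise ValueError("Invalid granularity specification")


if __name__ == "__main__":
    unittest.main()
```

`assertRaisesRegex(...)` can also be stored in a variable and entered later, which is how `test_fp8_linear_variants` builds its `error_context` in the diff below.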
60 changes: 29 additions & 31 deletions test/dtypes/test_affine_quantized_float.py
@@ -3,14 +3,14 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import unittest

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
raise unittest.SkipTest("Unsupported PyTorch version")

import copy
import io
@@ -20,11 +20,11 @@
from functools import partial
from typing import Tuple

import pytest
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.profiler import ProfilerActivity, profile
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import parametrize, run_tests

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
from torchao.float8.float8_utils import compute_error
@@ -75,12 +75,12 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@parametrize("dtype", [torch.bfloat16, torch.float32])
@parametrize("mode", ["dynamic", "weight-only", "static"])
@parametrize("compile", [True, False])
@parametrize("granularity", [PerTensor(), PerRow()])
# Inputs are (M,..), K, N
@common_utils.parametrize(
@parametrize(
"sizes",
[
((128,), 256, 128),
@@ -100,7 +100,7 @@ def test_fp8_linear_variants(
)

error_context = (
pytest.raises(AssertionError, match=error_message)
self.assertRaisesRegex(AssertionError, error_message)
if error_message
else nullcontext()
)
@@ -151,16 +151,16 @@ def test_fp8_linear_variants(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_mismatched_granularity(self):
with pytest.raises(
with self.assertRaisesRegex(
ValueError,
match="Different granularities for activation and weight are not supported",
"Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

@@ -171,7 +171,7 @@ def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
float8_dynamic_activation_float8_weight(
granularity=(UnsupportedGranularity(), UnsupportedGranularity())
)
@@ -181,9 +181,9 @@ class UnsupportedGranularity:
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_per_row_with_float32(self):
with pytest.raises(
with self.assertRaisesRegex(
AssertionError,
match="PerRow quantization only works for bfloat16 precision",
"PerRow quantization only works for bfloat16 precision",
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
@@ -194,7 +194,7 @@ def test_per_row_with_float32(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
@parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
model = ToyLinearModel(16, 32).to(device="cuda")
@@ -301,13 +301,11 @@ def test_fp8_weight_dimension_warning(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize(
"in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
)
@common_utils.parametrize(
@parametrize("in_features,out_features", [(512, 1024), (256, 768), (1024, 512)])
@parametrize(
"leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
) # fmt: skip
@common_utils.parametrize("bias", [True, False])
@parametrize("bias", [True, False])
def test_mm_float8dq_per_row(
self, in_features, out_features, leading_shape, bias: bool
):
@@ -355,8 +353,8 @@ def test_mm_float8dq_per_row(
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@parametrize("output_dtype", [torch.float32, torch.bfloat16])
def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
block_size = ()
device = "cuda"
@@ -398,9 +396,9 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
@parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@parametrize("output_dtype", [torch.float32, torch.bfloat16])
@parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
"""Test _dequantize_affine_float8 with various configurations"""

@@ -463,7 +461,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_tensor_slicing_basic(self, granularity):
"""Test basic slicing operations on Float8 tensors"""
device = "cuda"
@@ -596,7 +594,7 @@ def test_float8_tensor_slicing_edge_cases(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
is_sm_version(8, 9),
"TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
@@ -724,8 +722,8 @@ def test_preprocess_scale_3d_reshape(self):
@unittest.skipIf(
not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
@parametrize("granularity", [PerTensor(), PerRow()])
@parametrize(
"torch_compile_mode",
[
"default",
@@ -792,4 +790,4 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

if __name__ == "__main__":
pytest.main([__file__])
run_tests()
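For context on the harness the migrated file now ends with, here is a minimal, self-contained sketch of how torch's `parametrize`, `instantiate_parametrized_tests`, and `run_tests` fit together. The same three helpers appear in the diff above; the example class, test method, and shapes below are hypothetical and are not taken from this PR.

```python
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleParametrizedTest(TestCase):
    @parametrize("dtype", [torch.float32, torch.bfloat16])
    @parametrize("bias", [True, False])
    def test_linear_dtype(self, dtype, bias):
        # Each (dtype, bias) combination becomes its own concrete unittest test
        # method once instantiate_parametrized_tests() is applied to the class.
        linear = torch.nn.Linear(8, 4, bias=bias, dtype=dtype)
        x = torch.randn(2, 8, dtype=dtype)
        self.assertEqual(linear(x).dtype, dtype)


# Expand the parametrized methods into concrete test_* methods on the class.
instantiate_parametrized_tests(ExampleParametrizedTest)

if __name__ == "__main__":
    run_tests()  # unittest-based entry point, replacing pytest.main([__file__])
```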
3 changes: 1 addition & 2 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import unittest

import pytest
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
@@ -37,7 +36,7 @@
has_gemlite = False

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)
raise unittest.SkipTest("Skipping the test in ROCm")


class TestAffineQuantizedTensorParallel(DTensorTestBase):
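One detail worth flagging on the module-level skip above: `unittest.SkipTest` is a plain `Exception` subclass, so it accepts only a message; `allow_module_level=True` is specific to `pytest.skip()` and passing it to `SkipTest` would raise a `TypeError`. A minimal sketch of the intended unittest-style guard, using a hypothetical standalone module (the ROCm condition mirrors the file above):

```python
# Hypothetical test module illustrating the unittest-style module-level skip.
import unittest

import torch

if torch.version.hip is not None:
    # Raising SkipTest while the module is being imported for collection lets
    # the unittest loader report the whole module as skipped, mirroring the old
    # pytest.skip("...", allow_module_level=True) behavior.
    raise unittest.SkipTest("Skipping the test in ROCm")
```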