
Commit ad17cd1

Update
1 parent 58b35eb commit ad17cd1

File tree: 6 files changed, +29 −26 lines


test/float8/test_dtensor.py

Lines changed: 2 additions & 2 deletions

@@ -11,14 +11,14 @@
 """
 
 import os
+import unittest
 
-import pytest
 import torch
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
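The pattern here: pytest.skip(..., allow_module_level=True) is a pytest-only construct, while raising unittest.SkipTest during import marks the whole module as skipped under unittest's loader (and pytest's collection generally reports it as a skip as well), so the version guard now works under both runners. A minimal sketch of the new guard, mirroring the lines above:

import unittest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# Raising SkipTest at module scope aborts the import; the test loader
# records the module as skipped rather than failed.
if not TORCH_VERSION_AT_LEAST_2_5:
    raise unittest.SkipTest("Unsupported PyTorch version")

The same substitution is applied verbatim in test_fsdp.py, test_fsdp2_tp.py, and test_fsdp_compile.py below.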

test/float8/test_float8_utils.py

Lines changed: 1 addition & 2 deletions

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import unittest
 
-import pytest
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -76,7 +75,7 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
     )
     def test_non_float32_input(self, invalid_dtype: torch.dtype):
         non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
-        with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        with self.assertRaisesRegex(AssertionError, "scale must be float32 tensor"):
             _round_scale_down_to_power_of_2(non_float32_tensor)
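pytest.raises(..., match=...) and TestCase.assertRaisesRegex are near-equivalents: both treat the pattern as a regular expression that is searched (not fully matched) against the exception message. A self-contained sketch of the semantics, where _check_scale is a hypothetical stand-in for the dtype guard inside _round_scale_down_to_power_of_2:

import unittest

import torch


def _check_scale(scale: torch.Tensor) -> None:
    # Hypothetical stand-in for the real guard.
    assert scale.dtype == torch.float32, "scale must be float32 tensor"


class ScaleGuardTest(unittest.TestCase):
    def test_non_float32_input(self):
        bad = torch.tensor([3.0], dtype=torch.bfloat16)
        # The second argument is a regex searched in str(exception),
        # matching the semantics of match= in pytest.raises.
        with self.assertRaisesRegex(AssertionError, "scale must be float32 tensor"):
            _check_scale(bad)


if __name__ == "__main__":
    unittest.main()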

test/float8/test_fsdp.py

Lines changed: 2 additions & 2 deletions

@@ -13,15 +13,15 @@
 
 import copy
 import os
+import unittest
 import warnings
 
 import fire
-import pytest
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 import torch
 import torch.distributed as dist

test/float8/test_fsdp2_tp.py

Lines changed: 2 additions & 2 deletions

@@ -12,14 +12,14 @@
 
 import copy
 import os
+import unittest
 
-import pytest
 import torch
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

test/float8/test_fsdp_compile.py

Lines changed: 2 additions & 2 deletions

@@ -9,15 +9,15 @@
 """
 
 import os
+import unittest
 import warnings
 
 import fire
-import pytest
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 import torch
 import torch.distributed as dist

test/float8/test_numerics_integration.py

Lines changed: 20 additions & 16 deletions

@@ -7,9 +7,15 @@
 # Tests LLaMa FeedForward numerics with float8
 
 import copy
+import unittest
 from typing import Optional
 
-import pytest
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
@@ -18,7 +24,7 @@
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+    raise unittest.SkipTest("Unsupported PyTorch version")
 
 import torch
 import torch.nn as nn
@@ -83,7 +89,7 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TestFloat8NumericsIntegrationTest:
+class TestFloat8NumericsIntegrationTest(TestCase):
     def _test_impl(self, config: Float8LinearConfig) -> None:
         data_dtype = torch.bfloat16
         # LLaMa 3 70B shapes
@@ -147,22 +153,20 @@ def _test_impl(self, config: Float8LinearConfig) -> None:
         sqnr = compute_error(ref_grad, cur_grad)
         assert sqnr > grad_sqnr_threshold
 
-    @pytest.mark.parametrize(
+    @parametrize(
         "scaling_type_input",
         [ScalingType.DYNAMIC],
     )
-    @pytest.mark.parametrize(
+    @parametrize(
         "scaling_type_weight",
         [ScalingType.DYNAMIC],
     )
-    @pytest.mark.parametrize(
+    @parametrize(
         "scaling_type_grad_output",
         [ScalingType.DYNAMIC],
    )
-    @pytest.mark.skipif(
-        not is_sm_at_least_89(), reason="requires SM89 compatible machine"
-    )
-    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
+    @unittest.skipIf(not is_sm_at_least_89(), reason="requires SM89 compatible machine")
+    @unittest.skipIf(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_config_params(
         self,
         scaling_type_input: ScalingType,
@@ -177,17 +181,15 @@ def test_encoder_fw_bw_from_config_params(
         )
         self._test_impl(config)
 
-    @pytest.mark.parametrize(
+    @parametrize(
         "recipe_name",
         [
             Float8LinearRecipeName.ROWWISE,
             Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
         ],
     )
-    @pytest.mark.skipif(
-        not is_sm_at_least_90(), reason="requires SM90 compatible machine"
-    )
-    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
+    @unittest.skipIf(not is_sm_at_least_90(), reason="requires SM90 compatible machine")
+    @unittest.skipIf(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_recipe(
         self,
         recipe_name: str,
@@ -196,5 +198,7 @@ def test_encoder_fw_bw_from_recipe(
         self._test_impl(config)
 
 
+instantiate_parametrized_tests(TestFloat8NumericsIntegrationTest)
+
 if __name__ == "__main__":
-    pytest.main([__file__])
+    run_tests()
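For the parametrized tests, torch.testing._internal.common_utils provides a unittest-compatible analogue of pytest.mark.parametrize: decorate methods with @parametrize, inherit from its TestCase, and expand the decorators with instantiate_parametrized_tests after the class body; run_tests() is the matching entry point. A minimal sketch of the full pattern this file moves to (the test bodies are illustrative, not from the commit):

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleParametrizedTest(TestCase):
    @parametrize("exponent", [0, 1, 2])
    def test_power_of_two(self, exponent: int):
        # Expanded into one generated test method per parameter value.
        self.assertEqual(2**exponent, 1 << exponent)

    @unittest.skipIf(not torch.cuda.is_available(), reason="requires CUDA")
    def test_cuda_sum(self):
        # Standard-library skipIf replaces pytest.mark.skipif:
        # first argument is the condition, second the reason.
        t = torch.ones(4, device="cuda")
        self.assertEqual(t.sum().item(), 4.0)


# Expands each @parametrize decorator into concrete test methods;
# without this call the parametrized methods are not runnable.
instantiate_parametrized_tests(ExampleParametrizedTest)

if __name__ == "__main__":
    run_tests()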
