|
3 | 3 | #
|
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 |
| -import pytest |
| 6 | +from unittest import skipIf |
| 7 | + |
7 | 8 | import torch
|
| 9 | +from torch.testing._internal.common_utils import ( |
| 10 | + TestCase, |
| 11 | + instantiate_parametrized_tests, |
| 12 | + parametrize, |
| 13 | + run_tests, |
| 14 | +) |
8 | 15 | from torch.utils._triton import has_triton
|
9 | 16 |
|
10 | 17 | from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
|
|
# Axes along which pack/unpack round-trips are exercised: first, last, and a middle dim.
dimensions = (0, -1, 1)
|
14 | 21 |
|
15 | 22 |
|
16 |
class TestBitpacking(TestCase):
    """Round-trip tests for uintx bit-packing.

    Each test packs a uint8 tensor at a given bit width and checks that
    unpacking reproduces the original values exactly.
    """

    def tearDown(self):
        # Reset the dynamo cache between tests so compiled artifacts from
        # one parametrization cannot leak into the next.
        torch._dynamo.reset()

    @parametrize("bit_width", bit_widths)
    @parametrize("dim", dimensions)
    def test_CPU(self, bit_width, dim):
        # Values must fit in `bit_width` bits, hence the [0, 2**bit_width) range.
        test_tensor = torch.randint(
            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu"
        )
        packed = pack_cpu(test_tensor, bit_width, dim=dim)
        unpacked = unpack_cpu(packed, bit_width, dim=dim)
        assert unpacked.allclose(test_tensor)

    @skipIf(not torch.cuda.is_available(), "CUDA not available")
    @parametrize("bit_width", bit_widths)
    @parametrize("dim", dimensions)
    def test_GPU(self, bit_width, dim):
        test_tensor = torch.randint(
            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
        ).cuda()
        packed = pack(test_tensor, bit_width, dim=dim)
        unpacked = unpack(packed, bit_width, dim=dim)
        assert unpacked.allclose(test_tensor)

    # skipIf reasons are passed positionally here for consistency with the
    # other decorators in this class.
    @skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIf(not has_triton(), "unsupported without triton")
    @parametrize("bit_width", bit_widths)
    @parametrize("dim", dimensions)
    def test_compile(self, bit_width, dim):
        torch._dynamo.config.specialize_int = True
        # Bug fix: the original version discarded the torch.compile results
        # and then called the eager pack/unpack, so the compiled path was
        # never actually exercised. Bind and use the compiled callables.
        pack_compiled = torch.compile(pack, fullgraph=True)
        unpack_compiled = torch.compile(unpack, fullgraph=True)
        test_tensor = torch.randint(
            0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
        ).cuda()
        packed = pack_compiled(test_tensor, bit_width, dim=dim)
        unpacked = unpack_compiled(packed, bit_width, dim=dim)
        assert unpacked.allclose(test_tensor)

    # these test cases are for the example pack walk through in the bitpacking.py file
    @skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_pack_example(self):
        test_tensor = torch.tensor(
            [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
        ).cuda()
        shard_4, shard_2 = pack(test_tensor, 6)
        # Expected shard values are taken from the worked example in bitpacking.py.
        assert (
            torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
        )
        assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
        unpacked = unpack([shard_4, shard_2], 6)
        assert unpacked.allclose(test_tensor)

    def test_pack_example_CPU(self):
        test_tensor = torch.tensor(
            [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
        )
        shard_4, shard_2 = pack(test_tensor, 6)
        # Expected shard values are taken from the worked example in bitpacking.py.
        assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
        assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
        unpacked = unpack([shard_4, shard_2], 6)
        assert unpacked.allclose(test_tensor)
69 | 89 |
|
# Expand every @parametrize decorator on the class into concrete, individually
# addressable test methods (required by the torch parametrization framework).
instantiate_parametrized_tests(TestBitpacking)


if __name__ == "__main__":
    run_tests()
0 commit comments