|
15 | 15 |
|
16 | 16 | import pytest |
17 | 17 | import torch |
18 | | -from compressed_tensors.utils import safe_permute |
19 | | -from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES |
| 18 | +from compressed_tensors.utils.permute import safe_permute |
| 19 | +from tests.testing_utils import requires_gpu |
20 | 20 |
|
21 | 21 |
|
| 22 | +@requires_gpu |
22 | 23 | @pytest.mark.parametrize( |
23 | | - "dtype,device,exp_experimental", |
| 24 | + "dtype", |
24 | 25 | [ |
25 | | - (torch.int8, torch.device("cpu"), False), |
26 | | - (torch.int16, torch.device("cpu"), False), |
27 | | - (torch.int32, torch.device("cpu"), False), |
28 | | - (torch.int64, torch.device("cpu"), False), |
29 | | - (torch.float16, torch.device("cpu"), False), |
30 | | - (torch.float32, torch.device("cpu"), False), |
31 | | - (torch.float64, torch.device("cpu"), False), |
32 | | - (torch.float8_e4m3fn, torch.device("cpu"), True), |
| 26 | + torch.int8, |
| 27 | + torch.int16, |
| 28 | + torch.int32, |
| 29 | + torch.bfloat16, |
| 30 | + torch.float16, |
| 31 | + torch.float32, |
| 32 | + torch.float64, |
| 33 | + torch.float8_e4m3fn, |
33 | 34 | ], |
34 | 35 | ) |
35 | | -def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool): |
36 | | - # some dtypes do not support arange initialization |
37 | | - tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device) |
38 | | - perm = torch.tensor([3, 1, 0, 2]) |
39 | | - expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device) |
| 36 | +@pytest.mark.parametrize( |
| 37 | + "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")] |
| 38 | +) |
| 39 | +def test_safe_permute(dtype: torch.dtype, device: torch.device): |
| 40 | + value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device) |
| 41 | + perm = torch.tensor([3, 1, 0, 2], device=device) |
40 | 42 |
|
41 | | - result = safe_permute(tensor, perm, dim=0) |
| 43 | + result = safe_permute(value, perm, dim=-1) |
42 | 44 |
|
43 | | - if exp_experimental: |
44 | | - assert (dtype, device) in _EXPERIMENTAL_DTYPES |
45 | | - assert all(result == expected) |
| 45 | + if device.type != "meta": |
| 46 | + assert torch.equal(result.squeeze(0), perm.to(result.dtype)) |
0 commit comments