diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py index 8763e6ee..cb2cb491 100644 --- a/src/compressed_tensors/utils/__init__.py +++ b/src/compressed_tensors/utils/__init__.py @@ -18,7 +18,6 @@ from .match import * from .offload import * from .permutations_24 import * -from .permute import * from .safetensors_load import * from .semi_structured_conversions import * from .type import * diff --git a/src/compressed_tensors/utils/permute.py b/src/compressed_tensors/utils/permute.py deleted file mode 100644 index 86a0ee80..00000000 --- a/src/compressed_tensors/utils/permute.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from compressed_tensors.utils.helpers import deprecated - - -__all__ = ["safe_permute"] - - -@deprecated("Tensor.index_select") -def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor: - """ - Perform out-of-place permutation without using torch.Tensor.index_put_, - whose implementation is missing for datatypes such as `torch.float8_e4m3fn` - - :param value: tensor to permute - :param perm: permutation map - :param dim: dimension along which to apply permutation - :return: permuted value - """ - return value.index_select(dim, perm) diff --git a/tests/test_quantization/lifecycle/test_helpers.py b/tests/test_quantization/lifecycle/test_helpers.py deleted file mode 100644 index 20fd39da..00000000 --- a/tests/test_quantization/lifecycle/test_helpers.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -import torch -from compressed_tensors.utils.permute import safe_permute -from tests.testing_utils import requires_gpu - - -@requires_gpu -@pytest.mark.unit -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -@pytest.mark.parametrize( - "dtype", - [ - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.bfloat16, - torch.float16, - torch.float32, - torch.float64, - torch.float8_e4m3fn, - ], -) -@pytest.mark.parametrize( - "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")] -) -def test_safe_permute(dtype: torch.dtype, device: torch.device): - value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device) - perm = torch.tensor([3, 1, 0, 2], device=device) - - result = safe_permute(value, perm, dim=-1) - - if device.type != "meta": - assert torch.equal(result.squeeze(0), perm.to(result.dtype))