From 9c010158218d35d2d8b9d2a8658e9d5750460b09 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Wed, 17 Sep 2025 07:47:22 +0000 Subject: [PATCH 1/4] fix: avoid removing from tuple in _get_to_kwargs --- torchao/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/utils.py b/torchao/utils.py index daf7eab83c..dac1e68e47 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -717,9 +717,7 @@ def _get_tensor_impl_constructor( def _get_to_kwargs(self, *args, **kwargs): # `torch._C._nn._parse_to` can't handle `layout` argument - for arg in args: - if isinstance(arg, torch.layout): - args.remove(arg) + args = tuple(arg for arg in args if not isinstance(arg, torch.layout)) if "layout" in kwargs: kwargs.pop("layout") # ignoring `non_blocking` and `memory_format` args since these are not From 9af9cd4e75e83bac8da8d941954dfbc3c6d70131 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Thu, 18 Sep 2025 06:40:29 +0000 Subject: [PATCH 2/4] Add test case of _get_to_kwargs --- test/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index f06835c932..a5d50daa61 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -340,6 +340,9 @@ def __init__( ) self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + def test__get_to_kwargs_with_layout(self): + MyClass = TorchAOBaseTensor() + MyClass._get_to_kwargs(torch.strided,device="cuda") if __name__ == "__main__": unittest.main() From 0ae76dbb4e268b463b6207850004b25feb0d1772 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Thu, 18 Sep 2025 06:44:41 +0000 Subject: [PATCH 3/4] Update test case of _get_to_kwargs --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index a5d50daa61..edec27caa3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -342,7 +342,7 @@ def __init__( def test__get_to_kwargs_with_layout(self): MyClass = TorchAOBaseTensor() - MyClass._get_to_kwargs(torch.strided,device="cuda") + MyClass._get_to_kwargs(torch.strided, device="cuda") if __name__ == "__main__": unittest.main() From dc8d60784e200e96dfdee82c38a7beb5d8c86917 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Fri, 19 Sep 2025 08:31:42 +0000 Subject: [PATCH 4/4] Add test case for _get_to_kwargs in _test_default_impls_helper method --- test/test_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index edec27caa3..9f93e445bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -144,6 +144,10 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy): self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata)) self.assertEqual(lp_tensor.attr, reconstructed.attr) + # test _get_to_kwargs + _ = lp_tensor._get_to_kwargs(torch.strided, device="cuda") + _ = lp_tensor._get_to_kwargs(layout=torch.strided, device="cuda") + # `to` / `_to_copy` original_device = lp_tensor.device lp_tensor = lp_tensor.to("cuda") @@ -340,9 +344,6 @@ def __init__( ) self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) - def test__get_to_kwargs_with_layout(self): - MyClass = TorchAOBaseTensor() - MyClass._get_to_kwargs(torch.strided, device="cuda") if __name__ == "__main__": unittest.main()