Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
self.assertEqual(lp_tensor.attr, reconstructed.attr)

# test _get_to_kwargs
_ = lp_tensor._get_to_kwargs(torch.strided, device="cuda")
_ = lp_tensor._get_to_kwargs(layout=torch.strided, device="cuda")
Comment on lines +148 to +149
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, it might be better here to test the higher-level `to` op directly, e.g. `lp_tensor.to(...)`


# `to` / `_to_copy`
original_device = lp_tensor.device
lp_tensor = lp_tensor.to("cuda")
Expand Down
4 changes: 1 addition & 3 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,7 @@ def _get_tensor_impl_constructor(

def _get_to_kwargs(self, *args, **kwargs):
# `torch._C._nn._parse_to` can't handle `layout` argument
for arg in args:
if isinstance(arg, torch.layout):
args.remove(arg)
args = tuple(arg for arg in args if not isinstance(arg, torch.layout))
if "layout" in kwargs:
kwargs.pop("layout")
# ignoring `non_blocking` and `memory_format` args since these are not
Expand Down
Loading