fix: avoid removing from tuple in _get_to_kwargs #3018
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3018
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit dc8d607 with merge base 58c3064.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @jerryzh168 @huydhn, please help to review this one, thank you.
@orangeH25 thanks! can you add the test to
Got it, I'll add the test |
Hey @jerryzh168, the test case has been added. Please review it again. Thanks!
In `test/test_utils.py` (Outdated):

```python
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

def test__get_to_kwargs_with_layout(self):
    MyClass = TorchAOBaseTensor()
```
this is not the intended use for TorchAOBaseTensor, maybe you can add a test item in `_test_default_impls_helper` (line 108 in 18dbe87):

```python
def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
```

`_get_to_kwargs` is also a default impl
we can probably expand the op test for `to` (lines 148 to 152 in 18dbe87):

```python
original_device = lp_tensor.device
lp_tensor = lp_tensor.to("cuda")
self.assertEqual(lp_tensor.device.type, "cuda")
lp_tensor = lp_tensor.to(original_device)
self.assertEqual(lp_tensor.device, original_device)
```

and also add an explicit test for `_get_to_kwargs` as well
Thanks, I'll modify the test
Force-pushed from 10edb5f to dc8d607 (Compare)
```python
_ = lp_tensor._get_to_kwargs(torch.strided, device="cuda")
_ = lp_tensor._get_to_kwargs(layout=torch.strided, device="cuda")
```
btw, testing the higher-level `to` op directly might be better here, i.e. `lp_tensor.to(...)`

I'll merge for now and follow up to change the call to
This PR fixes a bug in `_get_to_kwargs` where `args` is a tuple but the code attempted to call `args.remove(arg)`, leading to an `AttributeError` (tuples are immutable and have no `remove` method).

Temporary test case in `test/test_utils.py::TestTorchAOBaseTensor(unittest.TestCase)`:

Temporary test command:
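The failure mode this PR fixes can be reproduced in plain Python. The sketch below is illustrative only: it is not torchao's actual `_get_to_kwargs` implementation, and the `LAYOUT` marker, the layout-filtering logic, and the `fixed_get_to_kwargs` name are assumptions based on the PR title and discussion (a real `torch.layout` value is not needed to show the bug):

```python
# Illustrative sketch of the bug: arguments captured via *args arrive
# as a tuple, and tuples have no .remove() method.
# LAYOUT is a hypothetical stand-in for a torch.layout positional arg.
LAYOUT = object()


def buggy_get_to_kwargs(*args, **kwargs):
    for arg in args:
        if arg is LAYOUT:
            args.remove(arg)  # AttributeError: 'tuple' object has no attribute 'remove'
    return args, kwargs


def fixed_get_to_kwargs(*args, **kwargs):
    # Fix sketch: build a new sequence instead of mutating the tuple.
    args = tuple(arg for arg in args if arg is not LAYOUT)
    return args, kwargs
```

With a change along these lines, passing the layout positionally no longer crashes, mirroring the `_get_to_kwargs(torch.strided, device="cuda")` call exercised in the test.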