Skip to content

Commit 10edb5f

Browse files
committed
Add test case for in method
1 parent 0ae76db commit 10edb5f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

test/test_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
144144
self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
145145
self.assertEqual(lp_tensor.attr, reconstructed.attr)
146146

147+
# test _get_to_kwargs
148+
_ = lp_tensor._get_to_kwargs(torch.strided, device="cuda")
149+
_ = lp_tensor._get_to_kwargs(layout=torch.strided, device="cuda")
150+
147151
# `to` / `_to_copy`
148152
original_device = lp_tensor.device
149153
lp_tensor = lp_tensor.to("cuda")
@@ -340,9 +344,6 @@ def __init__(
340344
)
341345
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
342346

343-
def test__get_to_kwargs_with_layout(self):
344-
MyClass = TorchAOBaseTensor()
345-
MyClass._get_to_kwargs(torch.strided, device="cuda")
346347

347348
if __name__ == "__main__":
348349
unittest.main()

0 commit comments

Comments
 (0)