lightning-thunder/thunder/tests/test_update_aliases.py
Lines 122 to 135 in 9e12768
```python
@ops(_inplace_opinfos, supported_dtypes=(dtypes.float32,))
def test_update_aliases(op, device, dtype, executor, _):
    sample = next(op.sample_inputs(device, dtype))
    # The sample generator is the one for `polygamma`.
    # `polygamma` expects an int as its first argument and a tensor as its second but
    # `polygamma_` wants opposite; tensor first, int second.
    args = list(sample.args)
    if op.name == "polygamma_":
        args[0], args[1] = args[1], args[0]
    j_op = executor.make_callable(op.torch_reference)
    actual = j_op(*args, **sample.kwargs)
    expected = op.torch_reference(*args, **sample.kwargs)
    torch.testing.assert_close(actual, expected, equal_nan=True)
```
This is a test for in-place ops such as Tensor.square_, Tensor.abs_, etc.
lightning-thunder/thunder/tests/test_update_aliases.py
Lines 132 to 135 in 9e12768
```python
    j_op = executor.make_callable(op.torch_reference)
    actual = j_op(*args, **sample.kwargs)
    expected = op.torch_reference(*args, **sample.kwargs)
    torch.testing.assert_close(actual, expected, equal_nan=True)
```
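op.torch_reference is exactly one of those in-place ops (e.g. Tensor.square_), so it returns the same tensor object it received as its first argument. That tensor ends up being mutated by both the jitted op and the eager op, and actual and expected both refer to it, so torch.testing.assert_close(actual, expected) passes trivially no matter what the jitted op computes.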
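The aliasing can be reproduced outside the test suite. The following is a minimal sketch using plain PyTorch only: `eager_op` and `broken_op` are hypothetical stand-ins (not Thunder code), with `broken_op` deliberately computing the wrong result to show that the comparison still passes.

```python
import torch


def eager_op(t):
    # In-place op: mutates t and returns t itself.
    return t.square_()


def broken_op(t):
    # Hypothetical stand-in for a miscompiled op: deliberately NOT a square.
    return t.add_(1.0)


x = torch.tensor([1.0, 2.0, 3.0])

actual = broken_op(x)    # x is now [2, 3, 4]
expected = eager_op(x)   # x is now [4, 9, 16]

# All three names refer to the same tensor object.
aliased = (actual is x) and (expected is x)

# Compares the tensor with itself, so it cannot fail.
torch.testing.assert_close(actual, expected, equal_nan=True)
```

Because both calls receive the same `x`, the second call also operates on the already-mutated input, so the "expected" value is not even the reference result for the original sample.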
You can verify this with this patch:
```diff
diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py
index 7cafcbec..e5fe7b8c 100644
--- a/thunder/tests/test_update_aliases.py
+++ b/thunder/tests/test_update_aliases.py
@@ -132,6 +132,7 @@ def test_update_aliases(op, device, dtype, executor, _):
     j_op = executor.make_callable(op.torch_reference)
     actual = j_op(*args, **sample.kwargs)
     expected = op.torch_reference(*args, **sample.kwargs)
+    assert id(args[0]) == id(actual) == id(expected), f"{id(args[0]) = }, {id(actual) = }, {id(expected) = }"
     torch.testing.assert_close(actual, expected, equal_nan=True)
```

This assertion passes for every op except op = tanhshrink (which is also strange, because tanhshrink is not even in-place).
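One possible direction for a fix, offered only as a sketch and not as the project's chosen approach, is to give each call its own clone of the inputs, so that the two results are distinct tensors computed from identical starting values. The helper name `compare_inplace` and its signature are assumptions made for illustration; in the real test, `op_under_test` would be the jitted callable and `reference_op` would be `op.torch_reference`.

```python
import torch


def compare_inplace(op_under_test, reference_op, *args, **kwargs):
    """Hypothetical helper: run both ops on independent copies of the inputs."""

    def clone_args():
        # Clone tensors; pass non-tensor arguments (e.g. ints) through unchanged.
        return [a.clone() if isinstance(a, torch.Tensor) else a for a in args]

    args_actual = clone_args()
    args_expected = clone_args()

    actual = op_under_test(*args_actual, **kwargs)
    expected = reference_op(*args_expected, **kwargs)

    # The two results must now be different objects, so the check is meaningful.
    assert actual is not expected
    torch.testing.assert_close(actual, expected, equal_nan=True)

    # Also check that the inputs were mutated in the same way.
    for a, e in zip(args_actual, args_expected):
        if isinstance(a, torch.Tensor):
            torch.testing.assert_close(a, e, equal_nan=True)
```

With this helper, `compare_inplace(torch.Tensor.square_, torch.Tensor.square_, x)` passes and leaves `x` untouched, while passing a deliberately wrong op such as `lambda t: t.add_(1.0)` against `torch.Tensor.square_` raises an `AssertionError`.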