-
Notifications
You must be signed in to change notification settings - Fork 108
Open
Description
Bug
On main, 9e12768
import torch, thunder, thunder.dynamo
def f():
torch.randn(4, device='cuda').sin_()
jf = thunder.dynamo.thunderfx(f)
# or jf = thunder.jit(f, fusion_type="dataflow")
jf()

Running this fails with:

  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 928, in fusion_pass
fusedtrace = self.cse(fusedtrace)
^^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 821, in cse
assert return_bsym.sym.id == prims.PrimIDs.RETURN
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
nvFuserExecutor fuses the trace into:
# Constructed by Remove redundant casts (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation():
# tmp/main.py:4: torch.randn(4, device='cuda').sin_()
t6 = torch.randn((4,), device=torch.device("cuda:0"), dtype=torch.float32) # t6: "cuda:0 f32[4]"
# t6 = ltorch.randn((4,), generator=None, dtype=torch.float32, device=torch.device("cuda:0"), layout=torch.strided, requires_grad=False, pin_memory=False, out=None) # t6: "cuda:0 f32[4]"
# t6 = prims.randn((4,), device=devices.Device("cuda:0"), dtype=dtypes.float32) # t6: "cuda:0 f32[4]"
return {'output': (), 'flat_args': []}
(t7,) = update_aliases((t6,))
nvFusion0(t7)
# t1 = prims.sin(t7) # t1: "cuda:0 f32[4]"
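# t2 = prims.copy_(t1, t7, grad_enabled=True) # t2: "cuda:0 f32[4]"

The `sin_` here on the intermediate `t7` is meaningless, but it survives DCE because `prims.copy_` has the `DONT_DCE` tag. Afterwards it gets fused into `nvFusion0`, which one might then expect to be DCE'd. However, `nvFusion0` will not be DCE'd either, because it has `copy_` as a subsymbol. As a result, `update_aliases` and `nvFusion0` remain after the `return` statement. This trips the `assert return_bsym.sym.id == prims.PrimIDs.RETURN` check in `cse`.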
Metadata
Metadata
Assignees
Labels
No labels