AssertionError when a mutation on an intermediate is reordered after the return #2776

@shino16

Description

Bug

On main, 9e12768

import torch, thunder, thunder.dynamo

def f():
    torch.randn(4, device='cuda').sin_()

jf = thunder.dynamo.thunderfx(f)
# or jf = thunder.jit(f, fusion_type="dataflow")
jf()
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 928, in fusion_pass
    fusedtrace = self.cse(fusedtrace)
                 ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 821, in cse
    assert return_bsym.sym.id == prims.PrimIDs.RETURN
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

nvFuserExecutor fuses the trace into the following. Note that the return statement comes before update_aliases and nvFusion0:

# Constructed by Remove redundant casts (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # tmp/main.py:4: 	    torch.randn(4, device='cuda').sin_()
  t6 = torch.randn((4,), device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[4]"
    # t6 = ltorch.randn((4,), generator=None, dtype=torch.float32, device=torch.device("cuda:0"), layout=torch.strided, requires_grad=False, pin_memory=False, out=None)  # t6: "cuda:0 f32[4]"
      # t6 = prims.randn((4,), device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[4]"
  return {'output': (), 'flat_args': []}
  (t7,) = update_aliases((t6,))
  nvFusion0(t7)
    # t1 = prims.sin(t7)  # t1: "cuda:0 f32[4]"
    # t2 = prims.copy_(t1, t7, grad_enabled=True)  # t2: "cuda:0 f32[4]"
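To make the ordering problem concrete, here is a minimal pure-Python sketch that reduces the trace above to its top-level symbol ids and checks the invariant the failing assertion enforces. It assumes, as the assertion `return_bsym.sym.id == prims.PrimIDs.RETURN` suggests, that cse treats the final bound symbol as the return; the names `FUSED_TRACE`, `return_is_last` and `symbols_after_return` are hypothetical and are not Thunder APIs.

```python
# Top-level symbol order of the fused trace above, reduced to symbol ids.
# (Hypothetical simplification: real traces hold BoundSymbol objects.)
FUSED_TRACE = ["randn", "RETURN", "update_aliases", "nvFusion0"]


def return_is_last(symbol_ids: list[str]) -> bool:
    # Models the cse assertion: the final bound symbol must be the return.
    return bool(symbol_ids) and symbol_ids[-1] == "RETURN"


def symbols_after_return(symbol_ids: list[str]) -> list[str]:
    # Everything scheduled after the return can never execute.
    if "RETURN" not in symbol_ids:
        return []
    return symbol_ids[symbol_ids.index("RETURN") + 1:]


print(return_is_last(FUSED_TRACE))        # False, so the assertion fires
print(symbols_after_return(FUSED_TRACE))  # ['update_aliases', 'nvFusion0']
```

In this model, any symbol left after the return both violates the invariant and is dead code, which is why the reordering shows up as an assertion failure rather than a wrong result.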

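The sin_ on the intermediate t7 is meaningless, but it survives DCE because prims.copy_ has the DONT_DCE tag. (Afterwards it gets fused into nvFusion0, which is then passed through DCE.) nvFusion0 is not removed by DCE because it has copy_ as a subsymbol. Since update_aliases and nvFusion0 remain after the return, the final bound symbol of the trace is not RETURN, which is exactly what the assertion in cse checks.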
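The DCE behavior described above can be illustrated with a toy backward-liveness pass. This is a hedged sketch, not Thunder's actual DCE: `BSym`, `dce` and the `dont_dce` flag are hypothetical stand-ins for bound symbols and the DONT_DCE tag, and it assumes (as the issue describes) that a fusion region is kept whenever any of its subsymbols is tagged.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class BSym:
    """Hypothetical stand-in for a bound symbol in a trace."""
    name: str
    inputs: tuple[str, ...] = ()
    outputs: tuple[str, ...] = ()
    dont_dce: bool = False  # models the DONT_DCE tag
    subsymbols: list[BSym] = field(default_factory=list)

    def is_protected(self) -> bool:
        # Assumption: a fusion region inherits protection from any tagged subsymbol.
        return self.dont_dce or any(s.is_protected() for s in self.subsymbols)


def dce(bsyms: list[BSym], live_outputs: set[str]) -> list[BSym]:
    """Walk backwards; keep a symbol if it is protected or produces a live value."""
    live = set(live_outputs)
    kept = []
    for b in reversed(bsyms):
        if b.is_protected() or live.intersection(b.outputs):
            kept.append(b)
            live.update(b.inputs)  # its inputs now become live
    kept.reverse()
    return kept


randn = BSym("randn", outputs=("t6",))
alias = BSym("update_aliases", inputs=("t6",), outputs=("t7",))
sin = BSym("sin", inputs=("t7",), outputs=("t1",))
copy_ = BSym("copy_", inputs=("t1", "t7"), outputs=("t2",), dont_dce=True)

# f returns nothing, so no value is live at the end of the trace.
before_fusion = [randn, alias, sin, copy_]
print([b.name for b in dce(before_fusion, set())])
# ['randn', 'update_aliases', 'sin', 'copy_']

# After fusion, sin and copy_ sit inside nvFusion0, which has no outputs.
nv_fusion0 = BSym("nvFusion0", inputs=("t7",), subsymbols=[sin, copy_])
after_fusion = [randn, alias, nv_fusion0]
print([b.name for b in dce(after_fusion, set())])
# ['randn', 'update_aliases', 'nvFusion0']
```

Without the tag on copy_, the same pass would drop every symbol, since nothing feeds a live output; the tag alone keeps the dead in-place update and, through it, the whole fusion region alive.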
