🚀 Feature
Thunder needs to support the `torch.sym_max` and `torch.sym_min` operations, which TorchDynamo uses to represent max/min over symbolic integer values, particularly in sliding-window attention patterns.
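For context, on concrete Python ints these ops agree with the builtins; the interesting case is when an argument is a `SymInt` under tracing. A quick sanity check (plain ints only, no symbolic tracing involved):

```python
import torch

# On concrete ints, torch.sym_max/sym_min behave like builtins max/min;
# with SymInt arguments they must instead preserve the symbolic expression.
assert torch.sym_max(3, 5) == 5
assert torch.sym_min(3, 5) == 3
```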
Current Behavior
Likely unsupported or falls back to eager evaluation.
Expected Behavior
Support symbolic max/min that preserves symbolic expressions:
- `max(Sym(s0 - 1055), 0)` → `Sym(Max(0, s0 - 1055))`
- Generate appropriate runtime checks
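PyTorch's symbolic shape expressions (the `Sym(...)` strings above) are backed by sympy, so the expected semantics can be illustrated with sympy directly; a minimal sketch, not Thunder code:

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)
expr = sympy.Max(0, s0 - 1055)  # stays symbolic: Max(0, s0 - 1055)

# The preserved expression still produces correct concrete values:
assert expr.subs(s0, 100) == 0
assert expr.subs(s0, 2000) == 945
```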
Example from torch.compile
sub: "Sym(s67 - 1056)" = l_kwargs_past_key_values_layers_0_cumulative_length - 1056
add_1: "Sym(s67 - 1055)" = sub + 1
sym_max: "Sym(Max(0, s67 - 1055))" = torch.sym_max(add_1, 0)This pattern is used for computing KV-cache offsets with sliding windows.
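A hedged sketch of the kind of function that yields such a graph: under `torch.compile` with dynamic shapes, a builtin `max` over a `SymInt` is traced as `torch.sym_max` (as the graph above shows). The tensor and slicing here are illustrative, not the exact HF code:

```python
import torch

@torch.compile(dynamic=True)
def sliding_window_slice(keys: torch.Tensor) -> torch.Tensor:
    cumulative_length = keys.shape[0]              # SymInt under dynamic shapes
    offset = max(cumulative_length - 1056 + 1, 0)  # traced as torch.sym_max
    return keys[offset:]
```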
Minimal Reproduction Case
```python
import torch
import thunder

@thunder.jit
def compute_kv_offset(cumulative_length: int, sliding_window: int = 1056) -> int:
    """Compute offset for sliding window attention."""
    offset_raw = cumulative_length - sliding_window + 1
    # Should support: torch.sym_max for symbolic values
    kv_offset = max(offset_raw, 0)  # or torch.sym_max(offset_raw, 0)
    return kv_offset

assert compute_kv_offset(100) == 0     # max(100 - 1056 + 1, 0) = 0
assert compute_kv_offset(2000) == 945  # max(2000 - 1056 + 1, 0) = 945

print(compute_kv_offset._lc_cs.last_epilogue_traces[-1])
# def epilogue():
#     return 945
```

Use Case
Critical for HF transformer models with sliding window attention (e.g., Mistral, Llama with sliding window).
Success Criteria
- Support `torch.sym_max(symbolic, constant)`
- Support `torch.sym_min(symbolic, constant)`
- Support `max(symbolic, constant)` as an alias
- Generated code produces correct runtime values
- Symbolic expression preserved for further optimization (see the test sketch below)
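A hedged acceptance-test sketch for these criteria; it assumes `thunder.jit` will accept both the builtin `max` spelling and `torch.sym_max` once support lands (the function names here are illustrative):

```python
import torch
import thunder

def offset_sym_max(n: int) -> int:
    return torch.sym_max(n - 1055, 0)

def offset_builtin_max(n: int) -> int:
    return max(n - 1055, 0)

def test_sym_max_support():
    for fn in (offset_sym_max, offset_builtin_max):
        jfn = thunder.jit(fn)
        for n in (0, 100, 1056, 2000):
            # Runtime values must match plain Python max
            assert jfn(n) == max(n - 1055, 0)
```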