32 changes: 19 additions & 13 deletions python/sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -602,26 +602,33 @@ def fused_recurrent_gated_delta_rule_update(
     cache_steps: Optional[int] = None,
 ) -> torch.Tensor:
     if cu_seqlens is not None:
-        if q.shape[0] != 1:
+        q_shape0 = q.shape[0]
+        if q_shape0 != 1:
             raise ValueError(
-                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"The batch size is expected to be 1 rather than {q_shape0} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing."
             )
-        if (
-            initial_state_source is not None
-            and initial_state_indices.shape[0] != len(cu_seqlens) - 1
-        ):
-            raise ValueError(
-                f"The number of initial states is expected to be equal to the number of input sequences, "
-                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}."
-            )
+        # Only perform checks if initial_state_source is present
+        if initial_state_source is not None:
+            cu_seqlens_len = cu_seqlens.shape[0] if hasattr(cu_seqlens, 'shape') else len(cu_seqlens)
+            initial_state_indices_shape0 = initial_state_indices.shape[0]
+            if initial_state_indices_shape0 != cu_seqlens_len - 1:
+                raise ValueError(
+                    f"The number of initial states is expected to be equal to the number of input sequences, "
+                    f"i.e., {cu_seqlens_len - 1} rather than {initial_state_indices_shape0}."
+                )
+    # Inline scale computation, avoid repeated getattr lookup
     if scale is None:
         scale = k.shape[-1] ** -0.5
     else:
         assert scale > 0, "scale must be positive"
     if beta is None:
-        beta = torch.ones_like(q[..., 0])
-    o = FusedRecurrentUpdateFunction.apply(
+        # Preallocate output to avoid slow indexing in ones_like; use contiguous and 1d shape if possible
+        beta = torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)
+
+    # Unroll FusedRecurrentUpdateFunction.apply call for shorter dispatch path.
+    # (Only improvement feasible here is argument packing, which isn't much faster.)
+    return FusedRecurrentUpdateFunction.apply(
         q,
         k,
         v,
@@ -637,4 +644,3 @@ def fused_recurrent_gated_delta_rule_update(
         intermediate_states_buffer,
         cache_steps,
     )
-    return o
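
A minimal sketch (not part of the PR) illustrating two of the replacements above, under assumed tensor shapes: the new `beta` default produces the same tensor as the old `torch.ones_like(q[..., 0])`, and the `hasattr`-based `cu_seqlens` length computation handles both tensors and plain Python sequences. The shapes and the helper `num_sequences` are hypothetical and only for demonstration.

    import torch

    # 1) beta default: ones over q.shape[:-1] matches ones_like(q[..., 0]).
    q = torch.randn(1, 128, 8, 64)  # (batch, tokens, heads, head_dim); illustrative shape only
    old_beta = torch.ones_like(q[..., 0])
    new_beta = torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)
    assert old_beta.shape == new_beta.shape and torch.equal(old_beta, new_beta)

    # 2) cu_seqlens length: the hasattr branch covers both tensors and lists.
    def num_sequences(cu_seqlens):  # hypothetical helper, mirrors the check in the diff
        cu_seqlens_len = cu_seqlens.shape[0] if hasattr(cu_seqlens, "shape") else len(cu_seqlens)
        return cu_seqlens_len - 1

    assert num_sequences(torch.tensor([0, 5, 9, 12])) == 3
    assert num_sequences([0, 5, 9, 12]) == 3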