diff --git a/python/sglang/srt/layers/attention/fla/fused_recurrent.py b/python/sglang/srt/layers/attention/fla/fused_recurrent.py
index 5e9a0c21ec3..73afbee9764 100644
--- a/python/sglang/srt/layers/attention/fla/fused_recurrent.py
+++ b/python/sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -602,26 +602,33 @@ def fused_recurrent_gated_delta_rule_update(
     cache_steps: Optional[int] = None,
 ) -> torch.Tensor:
     if cu_seqlens is not None:
-        if q.shape[0] != 1:
+        q_shape0 = q.shape[0]
+        if q_shape0 != 1:
             raise ValueError(
-                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"The batch size is expected to be 1 rather than {q_shape0} when using `cu_seqlens`. "
                 f"Please flatten variable-length inputs before processing."
             )
-        if (
-            initial_state_source is not None
-            and initial_state_indices.shape[0] != len(cu_seqlens) - 1
-        ):
-            raise ValueError(
-                f"The number of initial states is expected to be equal to the number of input sequences, "
-                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}."
-            )
+        # Only validate the initial-state indices when an initial state source is provided.
+        if initial_state_source is not None:
+            cu_seqlens_len = cu_seqlens.shape[0] if hasattr(cu_seqlens, "shape") else len(cu_seqlens)
+            initial_state_indices_shape0 = initial_state_indices.shape[0]
+            if initial_state_indices_shape0 != cu_seqlens_len - 1:
+                raise ValueError(
+                    f"The number of initial states is expected to be equal to the number of input sequences, "
+                    f"i.e., {cu_seqlens_len - 1} rather than {initial_state_indices_shape0}."
+                )
+    # Default the scale to 1/sqrt(head_dim) when it is not provided.
     if scale is None:
         scale = k.shape[-1] ** -0.5
     else:
         assert scale > 0, "scale must be positive"
     if beta is None:
-        beta = torch.ones_like(q[..., 0])
-    o = FusedRecurrentUpdateFunction.apply(
+        # Allocate the default beta directly instead of materializing q[..., 0] for ones_like.
+        beta = torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)
+
+    # Return the autograd.Function result directly instead of binding it to a
+    # temporary, so the trailing `return o` is no longer needed.
+    return FusedRecurrentUpdateFunction.apply(
         q,
         k,
         v,
@@ -637,4 +644,3 @@ def fused_recurrent_gated_delta_rule_update(
         intermediate_states_buffer,
         cache_steps,
     )
-    return o
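
Note on the checks and the beta default above: the snippet below is a minimal, hypothetical sketch of the invariants this hunk validates, assuming a flattened variable-length layout of q as [1, total_tokens, num_heads, head_dim] (the exact layout is not spelled out in the diff; the tensor sizes here are illustrative only). It also verifies that the new beta allocation matches the original torch.ones_like(q[..., 0]).

import torch

# Hypothetical flattened varlen batch: three sequences of lengths 4, 2, 5
# packed into a single batch of size 1, as the cu_seqlens path expects.
total_tokens, num_heads, head_dim = 11, 2, 8
q = torch.randn(1, total_tokens, num_heads, head_dim)
cu_seqlens = torch.tensor([0, 4, 6, 11], dtype=torch.int32)

# Invariant checked in the patch: one initial state per sequence.
initial_state_indices = torch.arange(cu_seqlens.shape[0] - 1)
assert initial_state_indices.shape[0] == cu_seqlens.shape[0] - 1

# The rewritten default-beta allocation is equivalent to the original
# torch.ones_like(q[..., 0]): one scalar per (batch, token, head).
beta_old = torch.ones_like(q[..., 0])
beta_new = torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)
assert beta_old.shape == beta_new.shape and torch.equal(beta_old, beta_new)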