⚡️ Speed up function fused_recurrent_gated_delta_rule_update by 72%
#326
📄 72% (0.72x) speedup for `fused_recurrent_gated_delta_rule_update` in `python/sglang/srt/layers/attention/fla/fused_recurrent.py`
⏱️ Runtime: 28.0 microseconds → 16.2 microseconds (best of 14 runs)
📝 Explanation and details
The optimized code achieves a 72% speedup by eliminating redundant attribute lookups and improving tensor creation efficiency. The key optimizations are:
1. Cached Attribute Lookups: stores `q.shape[0]` in a `q_shape0` variable instead of accessing it multiple times in error messages, and likewise caches `cu_seqlens.shape[0]` and `initial_state_indices.shape[0]` to avoid repeated `.shape` attribute lookups.
2. Optimized Beta Tensor Creation: replaces `torch.ones_like(q[..., 0])` with `torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)`. The original first materializes the slice `q[..., 0]` and then allocates a new tensor with `ones_like`; the optimized version allocates the ones tensor directly.
3. Streamlined Validation Logic: moves the `initial_state_source is not None` check earlier to avoid unnecessary `len()` calculations when there are no initial states, and uses a `hasattr()` check for safer attribute access on `cu_seqlens`.
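The beta-tensor change can be sketched in isolation. This is an illustrative snippet, not code from `fused_recurrent.py`; the tensor name `q` matches the report, but the shapes are made up:

```python
import torch

# Representative query tensor; the (batch, seq_len, head_dim) shape is illustrative.
q = torch.randn(2, 8, 64)

# Original approach: slice out the last dimension, then allocate a ones
# tensor shaped like that slice (two steps: an indexing op plus ones_like).
beta_old = torch.ones_like(q[..., 0])

# Optimized approach: allocate the ones tensor directly from the known
# shape, dtype, and device, skipping the intermediate slice.
beta_new = torch.ones(q.shape[:-1], dtype=q.dtype, device=q.device)

# Both produce the same all-ones tensor of shape q.shape[:-1].
assert beta_old.shape == beta_new.shape == (2, 8)
assert torch.equal(beta_old, beta_new)
```

The result is identical either way; the optimized form simply avoids constructing the throwaway `q[..., 0]` view before the allocation.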
Impact on Workloads: the function is called in the hot path of attention mechanisms for hybrid linear attention models, as shown in the function reference, where it is called within `forward_extend()` during model inference. Given that this function can be called thousands of times during sequence generation, the 72% speedup translates into meaningful wall-clock time savings. The optimizations particularly benefit test cases that exercise parameter validation (showing 166-175% improvements in edge-case tests) while preserving identical functionality and error-handling behavior.
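The attribute-caching pattern behind optimization 1 is not tensor-specific. A minimal pure-Python sketch of the same idea, with hypothetical names not taken from the PR:

```python
from types import SimpleNamespace

def validate_lengths(q, cu_seqlens):
    # Cache the .shape lookups in locals once, instead of re-reading the
    # attribute in every comparison and again in each error message.
    q_shape0 = q.shape[0]
    num_seqs = cu_seqlens.shape[0]
    if q_shape0 != 1:
        raise ValueError(f"expected batch size 1, got {q_shape0}")
    if num_seqs < 2:
        raise ValueError(f"cu_seqlens needs at least 2 entries, got {num_seqs}")
    return q_shape0, num_seqs

# Stand-ins for tensors: anything exposing a .shape tuple works here.
q = SimpleNamespace(shape=(1, 8, 64))
cu_seqlens = SimpleNamespace(shape=(3,))
print(validate_lengths(q, cu_seqlens))  # → (1, 3)
```

Each saved lookup is tiny, but in a function invoked thousands of times per generation step the savings accumulate, which is consistent with the measured gains on the validation-heavy edge cases.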
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To edit these changes
run `git checkout codeflash/optimize-fused_recurrent_gated_delta_rule_update-mhounpu2` and push.