Conversation

codeflash-ai bot commented Nov 7, 2025

📄 30% (0.30x) speedup for rope in invokeai/backend/flux/math.py

⏱️ Runtime: 5.28 milliseconds → 4.07 milliseconds (best of 186 runs)

📝 Explanation and details

The optimized code achieves a 29% speedup by replacing expensive tensor operations with more efficient alternatives:

Key Optimizations

1. Eliminated torch.einsum bottleneck (64% → 10% of runtime)

  • Replaced torch.einsum("...n,d->...nd", pos, omega) with direct broadcasting: pos.unsqueeze(-1) * omega
  • Broadcasting is significantly faster than einsum's generic tensor contraction logic
  • This single change accounts for most of the performance gain

2. Reduced trigonometric function calls

  • Original: Called cos and sin twice each within the stack operation
  • Optimized: Compute cos_out and sin_out once, then reuse them
  • Eliminates redundant trigonometric calculations

3. Faster tensor reshaping

  • Replaced einops.rearrange() with direct tensor.view() for the common case
  • .view() is a zero-copy reshape (valid here because the stacked tensor is contiguous) and avoids einops' more general parsing and reshaping logic
  • Falls back to rearrange only if needed (though this shouldn't happen in practice)

4. Reduced attribute access overhead

  • Cached pos.device and pos.dtype in variables to avoid repeated attribute lookups
  • Minor but consistent savings across all operations
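
Putting the four points together, the optimized function presumably looks roughly like the sketch below. This is reconstructed from the bullets above and the public rope signature in invokeai/backend/flux/math.py, not copied from the PR diff, so names (e.g. angles) and details such as the exact fallback condition may differ:

import torch
from einops import rearrange
from torch import Tensor

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    device = pos.device  # point 4: cache attribute lookups
    dtype = torch.float32 if device.type == "mps" else torch.float64
    scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    omega = 1.0 / (theta**scale)

    # point 1: plain broadcasting instead of torch.einsum("...n,d->...nd", pos, omega)
    angles = pos.unsqueeze(-1) * omega  # (..., n, dim // 2)

    # point 2: compute cos/sin once and reuse both results
    cos_out = torch.cos(angles)
    sin_out = torch.sin(angles)
    stacked = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)

    # point 3: zero-copy view for the common (b, n, d, 4) case, einops only as a fallback
    if stacked.dim() == 4:
        b, n, d, _ = stacked.shape
        out = stacked.view(b, n, d, 2, 2)
    else:
        out = rearrange(stacked, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()  # same final cast to float32 as the original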

Performance Impact

The optimizations are particularly effective for:

  • Large tensors: Tests with large batches (32×64) and dimensions (512) show 24-27% speedups
  • All tensor sizes: Even small tensors benefit significantly (27-37% faster)
  • Cross-device compatibility: Maintains the original MPS/CPU dtype selection logic while being faster

Based on the test results, this optimization provides consistent speedups across all input sizes and configurations, making it a valuable improvement for any workload using rotary position embeddings.
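
For point 1 in isolation, a quick sanity check that broadcasting matches the einsum result and is cheaper can be written with torch.utils.benchmark. The shapes below are an assumption borrowed from the 8×64, dim=32 test case further down, and absolute timings will vary by machine:

import torch
import torch.utils.benchmark as benchmark

pos = torch.arange(64, dtype=torch.float64).repeat(8, 1)                     # (8, 64)
omega = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float64) / 32))  # dim = 32

# The two formulations produce identical values...
ref = torch.einsum("...n,d->...nd", pos, omega)
fast = pos.unsqueeze(-1) * omega
assert torch.allclose(ref, fast)

# ...but the broadcasted multiply skips einsum's generic contraction machinery.
t_einsum = benchmark.Timer(
    stmt='torch.einsum("...n,d->...nd", pos, omega)',
    globals={"torch": torch, "pos": pos, "omega": omega},
).timeit(1000)
t_bcast = benchmark.Timer(
    stmt="pos.unsqueeze(-1) * omega",
    globals={"pos": pos, "omega": omega},
).timeit(1000)
print(t_einsum)
print(t_bcast)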

Correctness verification report:

Test                           Status
⚙️ Existing Unit Tests          🔘 None Found
🌀 Generated Regression Tests   37 Passed
⏪ Replay Tests                 🔘 None Found
🔎 Concolic Coverage Tests      🔘 None Found
📊 Tests Coverage               100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
# function to test
# (copied from invokeai/backend/flux/math.py)
from einops import rearrange
from invokeai.backend.flux.math import rope
from torch import Tensor

# unit tests

# ---- BASIC TEST CASES ----

def test_rope_basic_shape_and_dtype():
    # Test that output shape and dtype are correct for basic input
    pos = torch.tensor([[0.0, 1.0], [2.0, 3.0]])  # shape (2,2)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 146μs -> 110μs (32.8% faster)

def test_rope_basic_values_known():
    # Test that rope returns expected values for a simple known case
    pos = torch.zeros((1, 1))  # shape (1,1)
    dim = 2
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 138μs -> 108μs (27.3% faster)
    # With dim=2: scale = arange(0, 2, 2) / 2 = [0.0], so omega = 1 / theta**0 = [1.0]
    # pos = 0, so the angle is 0: cos(0) = 1, sin(0) = 0
    # stack gives [1, 0, 0, 1], i.e. a 2x2 identity, with output shape (1, 1, 1, 2, 2)
    expected = torch.tensor([[[[[1., 0.], [0., 1.]]]]])
    assert torch.allclose(out, expected)

def test_rope_basic_multiple_batch():
    # Test with batch dimension
    pos = torch.tensor([[0.0, 1.0], [2.0, 3.0]])  # shape (2,2)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 141μs -> 109μs (29.3% faster)

def test_rope_basic_float32_dtype():
    # Test with float32 dtype
    pos = torch.tensor([[0.0, 1.0]], dtype=torch.float32)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 139μs -> 107μs (29.1% faster)

# ---- EDGE TEST CASES ----

def test_rope_dim_not_even_raises():
    # dim must be even, otherwise assert triggers
    pos = torch.zeros((1, 1))
    dim = 3
    theta = 10000
    with pytest.raises(AssertionError):
        rope(pos, dim, theta) # 1.26μs -> 1.21μs (4.04% faster)


def test_rope_negative_theta():
    # Negative theta should not crash, but may produce nan/inf
    pos = torch.tensor([[1.0, 2.0]])
    dim = 2
    theta = -10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 162μs -> 131μs (23.8% faster)

def test_rope_theta_one():
    # theta=1, omega=1 for all
    pos = torch.tensor([[0.5, 1.0]])
    dim = 4
    theta = 1
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 141μs -> 109μs (29.0% faster)

def test_rope_large_theta():
    # Large theta should not crash
    pos = torch.tensor([[0.1, 0.2]])
    dim = 4
    theta = 1_000_000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 142μs -> 109μs (30.3% faster)

def test_rope_empty_pos():
    # Empty pos tensor should work (shape (0, n))
    pos = torch.empty((0, 2))
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 133μs -> 102μs (30.5% faster)

def test_rope_single_dim():
    # Single dim (minimum allowed even)
    pos = torch.tensor([[1.0]])
    dim = 2
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 138μs -> 106μs (29.7% faster)

def test_rope_high_values():
    # Test with large values in pos
    pos = torch.tensor([[1e10, -1e10]])
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 142μs -> 107μs (32.3% faster)

def test_rope_device_cpu():
    # Test that output stays on CPU
    pos = torch.tensor([[0.0, 1.0]])
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 138μs -> 106μs (30.4% faster)

def test_rope_large_batch():
    # Test with large batch size
    pos = torch.ones((1000, 2))
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 216μs -> 172μs (26.0% faster)

def test_rope_large_dim():
    # Test with large dim (max 1000 elements)
    pos = torch.ones((2, 2))
    dim = 1000
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 188μs -> 143μs (31.5% faster)

def test_rope_large_pos():
    # Test with large pos length
    pos = torch.ones((2, 1000))
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 212μs -> 165μs (28.7% faster)

def test_rope_large_all():
    # Test with all large dimensions, but keep total size reasonable
    pos = torch.ones((8, 32))
    dim = 64
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 268μs -> 215μs (24.5% faster)

def test_rope_noncontiguous_input():
    # Test with non-contiguous input tensor
    pos = torch.ones((10, 10))[::2]  # shape (5,10), noncontiguous
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 153μs -> 118μs (29.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from invokeai.backend.flux.math import rope

# unit tests

# ----------- BASIC TEST CASES -----------

def test_rope_basic_shape_and_dtype():
    # Test that the output shape and dtype are correct for a basic input
    pos = torch.tensor([[0, 1, 2]], dtype=torch.float32)  # shape (1, 3)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 164μs -> 127μs (28.9% faster)

def test_rope_basic_values():
    # Test that the output contains expected values for a simple input
    pos = torch.tensor([[0]], dtype=torch.float64)  # shape (1, 1)
    dim = 2
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 139μs -> 104μs (32.8% faster)
    # For pos=0, cos(0)=1, sin(0)=0
    # The output should be all 1s and 0s in the correct places
    # out shape: (1, 1, 1, 2, 2)
    # The first 2x2 matrix should be:
    # [[1, 0],
    #  [0, 1]]
    matrix = out[0, 0, 0]
    assert torch.allclose(matrix, torch.eye(2, dtype=matrix.dtype))

def test_rope_multiple_positions():
    # Test with multiple positions and higher dim
    pos = torch.tensor([[0, 1, 2, 3]], dtype=torch.float64)  # shape (1, 4)
    dim = 6
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 138μs -> 102μs (34.2% faster)

def test_rope_batch_input():
    # Test with batch input (multiple sequences)
    pos = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32)  # shape (2, 2)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 142μs -> 108μs (30.9% faster)

# ----------- EDGE TEST CASES -----------

def test_rope_dim_not_even():
    # Test that an odd dim raises an assertion error
    pos = torch.tensor([[0, 1]], dtype=torch.float32)
    dim = 3  # Not even
    theta = 10000
    with pytest.raises(AssertionError):
        rope(pos, dim, theta) # 1.26μs -> 1.23μs (2.28% faster)

def test_rope_dim_two():
    # Test with the smallest valid dim (2)
    pos = torch.tensor([[0, 1, 2]], dtype=torch.float64)
    dim = 2
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 143μs -> 110μs (30.4% faster)

def test_rope_theta_one():
    # Test with theta=1 (should not divide by zero, but omega=1 for all scale)
    pos = torch.tensor([[0, 1]], dtype=torch.float32)
    dim = 4
    theta = 1
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 140μs -> 107μs (30.9% faster)

def test_rope_negative_positions():
    # Test with negative positions
    pos = torch.tensor([[-1, 0, 1]], dtype=torch.float64)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 134μs -> 100μs (33.1% faster)

def test_rope_large_theta():
    # Test with a very large theta
    pos = torch.tensor([[0, 1]], dtype=torch.float32)
    dim = 4
    theta = int(1e9)
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 140μs -> 107μs (31.3% faster)

def test_rope_zero_positions():
    # Test with all positions zero
    pos = torch.zeros((2, 3), dtype=torch.float32)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 139μs -> 103μs (34.8% faster)
    # All cos(0)=1, sin(0)=0, so each 2x2 matrix should be identity
    for b in range(2):
        for n in range(3):
            for d in range(2):
                matrix = out[b, n, d]
                assert torch.allclose(matrix, torch.eye(2, dtype=matrix.dtype))

def test_rope_empty_positions():
    # Test with empty positions tensor
    pos = torch.empty((0, 3), dtype=torch.float32)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 132μs -> 101μs (31.5% faster)


def test_rope_single_position():
    # Test with a single position value
    pos = torch.tensor([[5]], dtype=torch.float64)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 155μs -> 124μs (25.5% faster)

def test_rope_non_default_device():
    # Test on CUDA if available, otherwise skip
    if torch.cuda.is_available():
        pos = torch.tensor([[0, 1, 2]], dtype=torch.float32, device='cuda')
        dim = 4
        theta = 10000
        codeflash_output = rope(pos, dim, theta); out = codeflash_output

def test_rope_mps_device():
    # Test on MPS if available, otherwise skip
    if torch.backends.mps.is_available():
        pos = torch.tensor([[0, 1]], dtype=torch.float32, device='mps')
        dim = 4
        theta = 10000
        codeflash_output = rope(pos, dim, theta); out = codeflash_output

# ----------- LARGE SCALE TEST CASES -----------

def test_rope_large_batch_and_seq():
    # Test with large batch and sequence length, but < 1000 elements
    batch_size = 32
    seq_len = 64
    pos = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(torch.float32)  # shape (32, 64)
    dim = 8
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 261μs -> 205μs (27.3% faster)

def test_rope_large_dim():
    # Test with large dim, but < 1000 elements
    pos = torch.tensor([[0, 1, 2]], dtype=torch.float32)
    dim = 512
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 172μs -> 136μs (25.8% faster)

def test_rope_max_elements():
    # Test with maximum allowed elements (batch * seq * dim < 1000)
    batch_size = 5
    seq_len = 10
    dim = 20
    pos = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(torch.float32)
    codeflash_output = rope(pos, dim, 10000); out = codeflash_output # 138μs -> 102μs (36.2% faster)

def test_rope_performance_large():
    # Check that function runs in reasonable time for large input
    batch_size = 8
    seq_len = 64
    dim = 32
    pos = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(torch.float32)
    codeflash_output = rope(pos, dim, 10000); out = codeflash_output # 247μs -> 198μs (24.5% faster)

def test_rope_dtype_preservation():
    # Test that dtype is preserved for float64 input
    pos = torch.tensor([[0, 1, 2]], dtype=torch.float64)
    dim = 4
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 146μs -> 111μs (31.7% faster)

def test_rope_large_float64():
    # Test with large float64 input
    pos = torch.arange(100).unsqueeze(0).to(torch.float64)  # shape (1, 100)
    dim = 10
    theta = 10000
    codeflash_output = rope(pos, dim, theta); out = codeflash_output # 130μs -> 94.8μs (37.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-rope-mhodz532 and push.

codeflash-ai bot requested a review from mashraf-222 on November 7, 2025 04:59
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Nov 7, 2025