
Conversation

codeflash-ai bot commented on Nov 7, 2025

📄 16% (0.16x) speedup for transform_select_experts_inputs in python/sglang/srt/eplb/expert_location_dispatch.py

⏱️ Runtime: 269 microseconds → 231 microseconds (best of 250 runs)

📝 Explanation and details

The optimization replaces `torch.zeros_like(correction_bias)` with `correction_bias.zero_()` on line 29. This is a micro-optimization that eliminates tensor allocation by performing an in-place operation instead of creating a new zero tensor.

**Key change:** When the dispatch algorithm is "fake" and `correction_bias` exists, the original code creates a new tensor with `torch.zeros_like()` and assigns it to the local variable. The optimized version zeros the existing tensor in place using `.zero_()`.
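
For context, here is a minimal sketch of the touched code path, reconstructed from this report and the expectations encoded in the generated tests below. It is not the verbatim sglang source; the uniform range and control flow are inferred from the test comments:

```python
import torch

def transform_select_experts_inputs(router_logits, correction_bias, info):
    # Sketch only: structure inferred from the report and tests, not copied
    # from python/sglang/srt/eplb/expert_location_dispatch.py.
    if info is not None and info.ep_dispatch_algorithm == "fake":
        # The tests expect logits drawn uniformly from [5, 10) on the fake path;
        # uniform_ raises a RuntimeError on integer tensors, as one test checks.
        router_logits.uniform_(5.0, 10.0)
        if correction_bias is not None:
            # Before: correction_bias = torch.zeros_like(correction_bias)  # fresh allocation
            correction_bias.zero_()  # After: zero in place; note this mutates the caller's tensor
    return router_logits, correction_bias
```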

**Why it's faster:** In-place operations avoid memory-allocation overhead and tensor-creation costs. `torch.zeros_like()` must allocate new memory and initialize it, while `.zero_()` simply writes zeros to existing memory locations.
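
To see the allocation overhead the report describes, a quick illustrative micro-benchmark (absolute numbers vary by machine and tensor size; this is not the harness Codeflash used):

```python
import timeit
import torch

bias = torch.ones(1000)

# zeros_like allocates and initializes a new tensor on every call
t_alloc = timeit.timeit(lambda: torch.zeros_like(bias), number=100_000)
# zero_ overwrites the existing buffer with no allocation
t_inplace = timeit.timeit(lambda: bias.zero_(), number=100_000)

print(f"torch.zeros_like: {t_alloc:.3f}s   tensor.zero_(): {t_inplace:.3f}s")
```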

**Performance impact:** The 16% overall speedup is most pronounced in test cases exercising the "fake" algorithm path with a non-None `correction_bias`, where per-test speedups range from 22% to 42%. Other code paths are essentially unaffected, since they never execute this line.

**Hot path relevance:** This function is called from `select_experts()` in the MoE (Mixture of Experts) layer, which runs during model inference for expert routing. Since MoE models route many tokens through expert selection, even micro-optimizations in tensor operations can accumulate into meaningful performance gains during inference workloads.

**Test case benefits:** The optimization particularly benefits scenarios with larger tensors on the "fake" dispatch path, as seen in the tests with 1000-element tensors showing 30%+ improvements.

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 30 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from typing import Optional

# imports
import pytest  # used for our unit tests
import torch
from sglang.srt.eplb.expert_location_dispatch import \
    transform_select_experts_inputs


# --- Dummy class for ExpertLocationDispatchInfo ---
class ExpertLocationDispatchInfo:
    def __init__(self, ep_dispatch_algorithm: Optional[str] = None):
        self.ep_dispatch_algorithm = ep_dispatch_algorithm

# --- Unit tests ---

# 1. Basic Test Cases

def test_basic_no_info_no_bias():
    # Basic: info is None, correction_bias is None, router_logits is a tensor
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), None, None) # 448ns -> 407ns (10.1% faster)

def test_basic_info_non_fake():
    # Basic: info is not None, but ep_dispatch_algorithm is not 'fake'
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    correction_bias = torch.tensor([0.1, 0.2, 0.3])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="real")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 629ns -> 624ns (0.801% faster)

def test_basic_info_fake_with_bias():
    # Basic: info is 'fake', correction_bias is provided
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    correction_bias = torch.tensor([0.1, 0.2, 0.3])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 10.7μs -> 8.23μs (29.6% faster)

def test_basic_info_fake_no_bias():
    # Basic: info is 'fake', correction_bias is None
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), None, info) # 5.31μs -> 5.19μs (2.39% faster)

# 2. Edge Test Cases

def test_empty_router_logits():
    # Edge: router_logits is an empty tensor
    router_logits = torch.tensor([])
    correction_bias = torch.tensor([])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 8.97μs -> 6.74μs (33.1% faster)

def test_single_element_tensor():
    # Edge: router_logits and correction_bias are single element tensors
    router_logits = torch.tensor([2.0])
    correction_bias = torch.tensor([0.5])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 8.53μs -> 6.44μs (32.4% faster)

def test_none_info_object():
    # Edge: info is None, should not modify tensors
    router_logits = torch.tensor([1.0, 2.0])
    correction_bias = torch.tensor([0.1, 0.2])
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), None) # 398ns -> 413ns (3.63% slower)

def test_info_missing_attribute():
    # Edge: info object without ep_dispatch_algorithm attribute
    class DummyInfo:
        pass
    info = DummyInfo()
    router_logits = torch.tensor([1.0, 2.0])
    correction_bias = torch.tensor([0.1, 0.2])
    # Should raise AttributeError
    with pytest.raises(AttributeError):
        transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 1.34μs -> 1.27μs (5.36% faster)

def test_correction_bias_none_with_fake():
    # Edge: correction_bias is None, info is 'fake'
    router_logits = torch.tensor([1.0, 2.0])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), None, info) # 6.46μs -> 6.32μs (2.26% faster)

def test_high_dimensional_tensor():
    # Edge: tensors with more than 2 dimensions
    router_logits = torch.ones((2, 3, 4))
    correction_bias = torch.ones((2, 3, 4))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 9.10μs -> 6.84μs (32.9% faster)

def test_non_float_tensor():
    # Edge: router_logits is integer type, should raise error on uniform_
    router_logits = torch.tensor([1, 2, 3], dtype=torch.int32)
    correction_bias = torch.tensor([1, 2, 3], dtype=torch.int32)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    # uniform_ only works on floating point tensors
    with pytest.raises(RuntimeError):
        transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 55.1μs -> 54.5μs (1.10% faster)

def test_router_logits_shared_memory():
    # Edge: router_logits is shared memory (simulate with .clone())
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    original = router_logits.clone()
    result_logits, _ = transform_select_experts_inputs(router_logits, None, info) # 5.95μs -> 5.79μs (2.59% faster)

# 3. Large Scale Test Cases

def test_large_tensor_fake_algorithm():
    # Large: router_logits and correction_bias are large tensors
    size = 1000  # keep under 100MB
    router_logits = torch.ones(size)
    correction_bias = torch.ones(size)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 11.7μs -> 9.00μs (30.0% faster)

def test_large_tensor_non_fake_algorithm():
    # Large: router_logits and correction_bias are large tensors, but algorithm is not 'fake'
    size = 1000
    router_logits = torch.arange(size, dtype=torch.float32)
    correction_bias = torch.arange(size, dtype=torch.float32)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="other")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 578ns -> 529ns (9.26% faster)

def test_large_2d_tensor_fake_algorithm():
    # Large: 2D tensors
    router_logits = torch.ones((100, 10))
    correction_bias = torch.ones((100, 10))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 12.8μs -> 9.89μs (29.4% faster)

def test_large_tensor_none_bias():
    # Large: router_logits is large, correction_bias is None
    router_logits = torch.ones(999)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), None, info) # 7.30μs -> 7.27μs (0.344% faster)

def test_large_tensor_empty_bias():
    # Large: router_logits is large, correction_bias is empty
    router_logits = torch.ones(999)
    correction_bias = torch.tensor([])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    result_logits, result_bias = transform_select_experts_inputs(router_logits.clone(), correction_bias.clone(), info) # 11.1μs -> 9.05μs (22.4% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from sglang.srt.eplb.expert_location_dispatch import \
    transform_select_experts_inputs


# function to test
class ExpertLocationDispatchInfo:
    """Mock class for info parameter with ep_dispatch_algorithm attribute."""
    def __init__(self, ep_dispatch_algorithm):
        self.ep_dispatch_algorithm = ep_dispatch_algorithm

# unit tests

# 1. Basic Test Cases

def test_no_info_returns_inputs_unchanged():
    # If info is None, router_logits and correction_bias should be unchanged
    router_logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    correction_bias = torch.tensor([0.5, 0.5])
    orig_router = router_logits.clone()
    orig_bias = correction_bias.clone()
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, None) # 392ns -> 373ns (5.09% faster)

def test_info_non_fake_algorithm_returns_inputs_unchanged():
    # If info.ep_dispatch_algorithm != "fake", no change
    router_logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    correction_bias = torch.tensor([0.5, 0.5])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="real")
    orig_router = router_logits.clone()
    orig_bias = correction_bias.clone()
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 591ns -> 585ns (1.03% faster)

def test_info_fake_algorithm_changes_logits_and_bias():
    # If info.ep_dispatch_algorithm == "fake", router_logits should be uniform in [5, 10), bias zeroed
    router_logits = torch.zeros((2, 3))
    correction_bias = torch.ones(3)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 11.2μs -> 7.86μs (41.9% faster)

def test_info_fake_algorithm_with_none_bias():
    # If correction_bias is None, should remain None
    router_logits = torch.zeros((2, 3))
    correction_bias = None
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 6.29μs -> 6.42μs (2.06% slower)

# 2. Edge Test Cases

def test_empty_router_logits_and_bias():
    # Empty tensors for router_logits and correction_bias
    router_logits = torch.empty((0, 0))
    correction_bias = torch.empty((0,))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 10.9μs -> 7.81μs (39.1% faster)
    # No elements to check for value, just shape

def test_router_logits_1d_tensor():
    # 1D tensor for router_logits
    router_logits = torch.tensor([1.0, 2.0, 3.0])
    correction_bias = torch.tensor([1.0, 2.0, 3.0])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 11.6μs -> 8.99μs (28.5% faster)

def test_router_logits_high_dimensional_tensor():
    # High dimensional tensor
    router_logits = torch.zeros((2, 2, 2, 2))
    correction_bias = torch.ones((2, 2))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 8.90μs -> 6.27μs (41.9% faster)

def test_correction_bias_different_shape():
    # correction_bias shape different from router_logits
    router_logits = torch.zeros((4, 5))
    correction_bias = torch.ones((5,))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 8.80μs -> 6.22μs (41.5% faster)

def test_router_logits_and_bias_with_negative_values():
    # Input tensors with negative values
    router_logits = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]])
    correction_bias = torch.tensor([-0.5, -0.5])
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 9.85μs -> 7.71μs (27.7% faster)

# 3. Large Scale Test Cases

def test_large_router_logits_and_bias():
    # Large but not huge tensors
    router_logits = torch.zeros((100, 100))
    correction_bias = torch.ones((100,))
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 34.8μs -> 31.1μs (11.7% faster)

def test_large_router_logits_none_bias():
    # Large router_logits, None correction_bias
    router_logits = torch.zeros((200, 4))
    correction_bias = None
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="fake")
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 8.23μs -> 8.23μs (0.012% faster)

def test_large_router_logits_and_bias_non_fake():
    # Large tensors, but info not "fake" - should be unchanged
    router_logits = torch.arange(100*10, dtype=torch.float32).reshape(100, 10)
    correction_bias = torch.arange(10, dtype=torch.float32)
    info = ExpertLocationDispatchInfo(ep_dispatch_algorithm="real")
    orig_router = router_logits.clone()
    orig_bias = correction_bias.clone()
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, info) # 574ns -> 552ns (3.99% faster)

def test_large_router_logits_and_bias_none_info():
    # Large tensors, info is None - should be unchanged
    router_logits = torch.arange(50*20, dtype=torch.float32).reshape(50, 20)
    correction_bias = torch.arange(20, dtype=torch.float32)
    orig_router = router_logits.clone()
    orig_bias = correction_bias.clone()
    out_logits, out_bias = transform_select_experts_inputs(router_logits, correction_bias, None) # 400ns -> 407ns (1.72% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-transform_select_experts_inputs-mhos1r3v` and push.

codeflash-ai bot requested a review from mashraf-222 on November 7, 2025 at 11:33
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Nov 7, 2025