Conversation

codeflash-ai bot commented Nov 7, 2025

📄 51% (0.51x) speedup for _topk_ids_logical_to_physical_static in python/sglang/srt/eplb/expert_location_dispatch.py

⏱️ Runtime : 540 microseconds → 357 microseconds (best of 250 runs)

📝 Explanation and details

The optimization replaces standard tensor indexing with torch.take() when conditions are favorable, achieving a 51% speedup (540μs → 357μs).

Key Optimization:

  • Conditional use of torch.take(): When the mapping tensor is 1D and indices are torch.long, the code uses torch.take(partial_map, topk_ids) instead of partial_map[topk_ids]
  • Safe fallback: If conditions aren't met (non-1D mapping or non-long indices), it falls back to the original indexing approach

Why it's faster:
torch.take() is PyTorch's optimized function specifically designed for 1D tensor indexing. It bypasses the general-purpose advanced indexing machinery that handles arbitrary dimensional cases, resulting in more efficient memory access patterns and reduced overhead.
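
For illustration, here is a minimal sketch of the pattern described above (the attribute and function names follow the test mocks further down; the exact guard in the PR diff may differ):

```python
import torch

def _topk_ids_logical_to_physical_static_sketch(topk_ids, info):
    # Sketch only: `info` carries the 1D logical -> physical expert mapping
    partial_map = info.partial_logical_to_rank_dispatch_physical_map
    if partial_map.dim() == 1 and topk_ids.dtype == torch.long:
        # Fast path: torch.take gathers from the flat mapping and returns
        # a tensor shaped like topk_ids
        return torch.take(partial_map, topk_ids)
    # Fallback: general advanced indexing for other dtypes/shapes
    return partial_map[topk_ids]
```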

Performance characteristics from tests:

  • Best case: 89-139% speedup for typical use cases (1D mappings with long indices)
  • Edge cases: Still performs well with empty tensors, high-dimensional indices, and large scales
  • Regression case: ~25% slower when topk_ids uses int32 dtype (falls back to original method)
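
A small illustration of why the int32 case takes the fallback (illustrative values; torch.take only accepts long indices, whereas plain indexing also handles other integer dtypes):

```python
import torch

partial_map = torch.tensor([10, 20, 30])

# long indices satisfy torch.take's requirement, so the fast path applies
fast = torch.take(partial_map, torch.tensor([0, 2], dtype=torch.long))

# int32 indices would make torch.take raise, so the code falls back to
# ordinary advanced indexing, which accepts them
slow = partial_map[torch.tensor([0, 2], dtype=torch.int32)]

assert torch.equal(fast, slow)  # both yield tensor([10, 30])
```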

Impact on workloads:
This function is called from topk_ids_logical_to_physical() in expert routing scenarios, likely in inference hot paths where expert selection happens frequently. The optimization particularly benefits:

  • Large-scale model inference with many experts
  • Batch processing scenarios with substantial topk_ids tensors
  • Systems where expert dispatch happens repeatedly per forward pass

The conditional approach ensures no behavioral changes while maximizing performance for the common case of 1D mappings with proper integer types.
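
A quick sanity check of that equivalence on the common case (illustrative values only):

```python
import torch

partial_map = torch.tensor([11, 22, 33, 44, 55])
topk_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)

# torch.take preserves the shape of topk_ids, matching plain indexing
assert torch.equal(torch.take(partial_map, topk_ids), partial_map[topk_ids])
```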

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 33 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from sglang.srt.eplb.expert_location_dispatch import \
    _topk_ids_logical_to_physical_static


# Mock of the dispatch-info container read by the function under test
class ExpertLocationDispatchInfo:
    """
    Minimal mock of the ExpertLocationDispatchInfo class for testing purposes.
    """
    def __init__(self, partial_logical_to_rank_dispatch_physical_map):
        self.partial_logical_to_rank_dispatch_physical_map = partial_logical_to_rank_dispatch_physical_map

# unit tests

# ==============================
# 1. Basic Test Cases
# ==============================

def test_basic_single_element():
    # Test with a single-element topk_ids and a simple mapping
    mapping = torch.tensor([10, 20, 30])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([1])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 13.7μs -> 7.23μs (89.4% faster)

def test_basic_multiple_elements():
    # Test with multiple elements in topk_ids
    mapping = torch.tensor([5, 6, 7, 8])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 2, 3])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 13.1μs -> 6.37μs (105% faster)

def test_basic_repeated_indices():
    # Test with repeated indices in topk_ids
    mapping = torch.tensor([1, 2, 3])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([2, 2, 0, 1])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.4μs -> 6.41μs (93.0% faster)

def test_basic_different_dtypes():
    # Test with different integer dtypes
    mapping = torch.tensor([100, 200, 300], dtype=torch.int64)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 2], dtype=torch.int32)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 21.1μs -> 28.3μs (25.3% slower)

def test_basic_2d_tensor():
    # Test with a 2D topk_ids tensor
    mapping = torch.tensor([11, 22, 33, 44, 55])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([[0, 1], [2, 3]])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 13.3μs -> 7.80μs (70.6% faster)
    expected = torch.tensor([[11, 22], [33, 44]])
    assert torch.equal(result, expected)

# ==============================
# 2. Edge Test Cases
# ==============================

def test_edge_empty_topk_ids():
    # Test with empty topk_ids tensor
    mapping = torch.tensor([1, 2, 3])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([], dtype=torch.long)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.1μs -> 5.59μs (98.9% faster)

def test_edge_empty_mapping():
    # Test with empty mapping and empty topk_ids
    mapping = torch.tensor([], dtype=torch.long)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([], dtype=torch.long)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.5μs -> 5.27μs (118% faster)

def test_edge_out_of_bounds_index():
    # Test with a topk_id that is out of bounds
    mapping = torch.tensor([1, 2, 3])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 3])  # 3 is out of bounds
    with pytest.raises(IndexError):
        _topk_ids_logical_to_physical_static(topk_ids, info)

def test_edge_negative_index():
    # Test with negative indices (should work like python indexing)
    mapping = torch.tensor([10, 20, 30])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([-1, 0])  # -1 should map to 30
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 13.9μs -> 7.32μs (89.6% faster)

def test_edge_large_index_but_valid():
    # Test with the largest valid index
    mapping = torch.arange(100)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([99])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.0μs -> 5.50μs (99.7% faster)

def test_edge_noncontiguous_mapping():
    # Test with a mapping tensor that is not contiguous in memory
    mapping = torch.arange(10)[::2]  # [0,2,4,6,8]
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 2, 4])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.0μs -> 5.52μs (99.3% faster)

def test_edge_noncontiguous_topk_ids():
    # Test with a topk_ids tensor that is not contiguous
    mapping = torch.tensor([10, 20, 30, 40, 50])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids_full = torch.tensor([0, 1, 2, 3, 4])
    topk_ids = topk_ids_full[::2]  # [0,2,4]
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.5μs -> 6.66μs (87.8% faster)

def test_edge_high_rank_tensor():
    # Test with a high-rank (3D) tensor
    mapping = torch.arange(100)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.9μs -> 5.39μs (139% faster)
    expected = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    assert torch.equal(result, expected)

# ==============================
# 3. Large Scale Test Cases
# ==============================

def test_large_scale_many_elements():
    # Test with a large mapping and topk_ids
    mapping = torch.arange(1000)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.arange(1000)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 13.8μs -> 6.97μs (98.7% faster)

def test_large_scale_random_indices():
    # Test with random indices within bounds
    torch.manual_seed(42)
    mapping = torch.arange(1000)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.randint(0, 1000, (500,))
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.8μs -> 6.32μs (102% faster)
    # Check that each output is equal to mapping at that index
    for i in range(topk_ids.size(0)):
        assert result[i] == mapping[topk_ids[i]]

def test_large_scale_high_dimensional():
    # Test with a large, high-dimensional topk_ids tensor
    mapping = torch.arange(100)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.arange(24).reshape(2, 3, 4)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.4μs -> 4.96μs (130% faster)

def test_large_scale_performance():
    # Test that the function works for the upper limit of allowed size (~1000 elements)
    mapping = torch.arange(999)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.arange(999)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.9μs -> 6.49μs (98.3% faster)

def test_large_scale_maximum_valid_index():
    # Test with topk_ids containing the maximum valid index for the mapping
    mapping = torch.arange(1000)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([999])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 10.5μs -> 4.93μs (113% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest
import torch
from sglang.srt.eplb.expert_location_dispatch import \
    _topk_ids_logical_to_physical_static


# Dummy class to simulate ExpertLocationDispatchInfo for testing
class ExpertLocationDispatchInfo:
    def __init__(self, partial_logical_to_rank_dispatch_physical_map):
        self.partial_logical_to_rank_dispatch_physical_map = partial_logical_to_rank_dispatch_physical_map

# unit tests

# 1. Basic Test Cases

def test_basic_identity_mapping():
    # Test with identity mapping (logical == physical)
    mapping = torch.arange(10)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 1, 2, 3, 4])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.7μs -> 4.91μs (138% faster)

def test_basic_non_identity_mapping():
    # Test with a shuffled mapping
    mapping = torch.tensor([3, 1, 4, 0, 2])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 2, 4])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.7μs -> 6.38μs (99.7% faster)

def test_basic_single_element():
    # Test with a single element
    mapping = torch.tensor([10, 20, 30])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([1])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.4μs -> 6.68μs (84.8% faster)

def test_basic_repeated_indices():
    # Test with repeated indices in topk_ids
    mapping = torch.tensor([5, 6, 7])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([2, 2, 1, 0])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.9μs -> 6.68μs (92.7% faster)

# 2. Edge Test Cases

def test_empty_topk_ids():
    # Test with empty topk_ids tensor
    mapping = torch.arange(10)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([], dtype=torch.long)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 10.3μs -> 4.50μs (130% faster)

def test_empty_mapping_and_topk_ids():
    # Both mapping and topk_ids empty
    mapping = torch.tensor([], dtype=torch.long)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([], dtype=torch.long)
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.2μs -> 5.25μs (113% faster)

def test_out_of_bounds_index_raises():
    # Test with out-of-bounds index in topk_ids
    mapping = torch.tensor([1, 2, 3])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 1, 3])  # 3 is out of bounds
    with pytest.raises(IndexError):
        _topk_ids_logical_to_physical_static(topk_ids, info) # 63.7μs -> 60.3μs (5.56% faster)

def test_negative_index():
    # Test with negative index in topk_ids
    mapping = torch.tensor([10, 20, 30])
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([-1, 0])  # -1 should wrap to last element
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 14.0μs -> 7.58μs (85.0% faster)

def test_large_indices():
    # Test with large indices, but within bounds
    mapping = torch.arange(100)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([99, 0, 50])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 11.8μs -> 5.52μs (114% faster)

def test_non_contiguous_mapping():
    # Test with a mapping tensor that is not contiguous in memory
    mapping = torch.arange(20)[::2]  # [0,2,4,6,8,10,12,14,16,18]
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([0, 3, 5, 9])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 10.7μs -> 5.38μs (98.5% faster)

def test_different_dtypes():
    # Test with different integer dtypes
    mapping = torch.tensor([1, 2, 3], dtype=torch.int64)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([1, 2], dtype=torch.int32)
    # torch supports mixing int32 indices with int64 mapping
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 20.2μs -> 25.8μs (21.8% slower)

def test_high_dimensional_topk_ids():
    # Test with 2D topk_ids tensor
    mapping = torch.arange(10)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.tensor([[0, 1], [2, 3]])
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 12.5μs -> 6.74μs (84.8% faster)


# 3. Large Scale Test Cases

def test_large_scale():
    # Test with large mapping and topk_ids tensors (but <100MB)
    n = 1000
    mapping = torch.arange(n)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.arange(n-1, -1, -1)  # reversed indices
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 22.0μs -> 10.2μs (117% faster)
    expected = torch.arange(n-1, -1, -1)
    assert torch.equal(result, expected)

def test_large_randomized():
    # Test with randomized mapping and indices
    n = 1000
    mapping = torch.randperm(n)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.randint(0, n, (n,))
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 15.6μs -> 7.10μs (120% faster)
    # Check that each result is equal to mapping[topk_ids[i]]
    for i in range(n):
        assert result[i] == mapping[topk_ids[i]]

def test_large_scale_2d():
    # Test with a large 2D topk_ids tensor
    n, m = 32, 32  # 32*32 = 1024 elements
    mapping = torch.arange(2000)
    info = ExpertLocationDispatchInfo(mapping)
    topk_ids = torch.randint(0, 2000, (n, m))
    codeflash_output = _topk_ids_logical_to_physical_static(topk_ids, info); result = codeflash_output # 15.1μs -> 7.23μs (109% faster)
    # Spot-check a subset of positions against the mapping
    for i in range(0, n, 10):
        for j in range(0, m, 10):
            assert result[i, j] == mapping[topk_ids[i, j]]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, `git checkout codeflash/optimize-_topk_ids_logical_to_physical_static-mhosffbw` and push.
