
Conversation


@codeflash-ai codeflash-ai bot commented Nov 7, 2025

📄 3,041% (30.41x) speedup for MoeWNA16Config.is_moe_wna16_compatible in python/sglang/srt/layers/quantization/moe_wna16.py

⏱️ Runtime : 33.7 milliseconds → 1.07 milliseconds (best of 63 runs)

📝 Explanation and details

The optimization adds @lru_cache(maxsize=1) to the get_device_capability function, which provides a dramatic 31x speedup by eliminating redundant torch CUDA calls.

Key Optimization:

  • Function Caching: The @lru_cache(maxsize=1) decorator caches the result of get_device_capability(), which involves expensive torch API calls like torch.cuda.is_available() and torch.cuda.get_device_capability().

Why This Works:

  • Device capabilities are static hardware properties that don't change during program execution
  • The original code called get_device_capability() repeatedly (1680 times in profiling), each time making expensive torch API calls
  • Line profiler shows the original function spent 64.4ms total (91.5% in torch.cuda.is_available() alone), while the cached version takes only 2.6ms

Performance Impact:

  • Test results show consistent 25-30x speedups across all test cases
  • Most beneficial for workloads that repeatedly check quantization compatibility, such as model initialization or configuration validation
  • The maxsize=1 is sufficient since device capabilities rarely vary within a single process, and the default device_id=0 covers the common use case
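
As a concrete illustration of that last point, plain functools is enough to see how a maxsize=1 cache keyed on device_id behaves (probe_capability below is a hypothetical stand-in, not sglang code):

from functools import lru_cache

call_count = 0

@lru_cache(maxsize=1)
def probe_capability(device_id: int = 0) -> tuple:
    # Stand-in for an expensive device query; counts how often the body really runs.
    global call_count
    call_count += 1
    return (7, 5)  # pretend (major, minor) compute capability

probe_capability()   # miss: body runs once
probe_capability()   # hit: served from the single cached slot
probe_capability(1)  # different argument: evicts the old entry and runs again
print(call_count)                     # 2
print(probe_capability.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=1, currsize=1)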

Use Cases:
The optimization particularly benefits scenarios where MoeWNA16Config.is_moe_wna16_compatible() is called frequently, such as during model loading, quantization setup, or configuration validation loops - all critical paths in ML inference systems.
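
For reference, here is a minimal sketch of the cached helper this explanation describes. It is not the exact sglang implementation; the CPU fallback value and the way the caller folds (major, minor) into a two-digit number are assumptions based on the capability values used in the tests below.

from functools import lru_cache
from typing import Optional, Tuple

import torch

@lru_cache(maxsize=1)
def get_device_capability(device_id: int = 0) -> Tuple[Optional[int], Optional[int]]:
    # The expensive torch queries run at most once per process; later calls are cache hits.
    if not torch.cuda.is_available():
        return (None, None)
    major, minor = torch.cuda.get_device_capability(device_id)
    return (major, minor)

def meets_awq_min_capability(minimum: int = 75) -> bool:
    # Example caller: combine (major, minor) into e.g. 75 for (7, 5) and compare,
    # mirroring the AWQ capability checks encoded in the regression tests below.
    major, minor = get_device_capability()
    if major is None:
        return False
    return major * 10 + minor >= minimum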

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 1680 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# --------------------------
# Unit tests for MoeWNA16Config.is_moe_wna16_compatible
# --------------------------

# --- BASIC TEST CASES ---

def test_gptq_compatible_4bit():
    # Basic: GPTQ, 4 bits, desc_act False
    cfg = {
        "quant_method": "gptq",
        "bits": 4,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 94.7μs -> 3.29μs (2779% faster)

def test_gptq_compatible_8bit():
    # Basic: GPTQ, 8 bits, desc_act False
    cfg = {
        "quant_method": "gptq",
        "bits": 8,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 71.4μs -> 2.47μs (2791% faster)

def test_gptq_incompatible_desc_act_true():
    # Basic: GPTQ, 4 bits, desc_act True (should be incompatible)
    cfg = {
        "quant_method": "gptq",
        "bits": 4,
        "desc_act": True
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 66.8μs -> 2.25μs (2866% faster)

def test_awq_compatible_min_capability():
    # Basic: AWQ, 4 bits, meets min capability
    cfg = {
        "quant_method": "awq",
        "bits": 4,
        "desc_act": False,
        "_device_capability_tuple": (7, 5),  # 75
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 62.4μs -> 2.25μs (2667% faster)

def test_awq_incompatible_low_capability():
    # Basic: AWQ, 4 bits, below min capability
    cfg = {
        "quant_method": "awq",
        "bits": 4,
        "desc_act": False,
        "_device_capability_tuple": (7, 0),  # 70
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 61.5μs -> 2.12μs (2806% faster)

def test_awq_incompatible_wrong_bits():
    # Basic: AWQ, 8 bits (should be incompatible)
    cfg = {
        "quant_method": "awq",
        "bits": 8,
        "desc_act": False,
        "_device_capability_tuple": (8, 0),
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 61.3μs -> 2.09μs (2835% faster)

def test_other_quant_method():
    # Basic: Unknown quant method (should be incompatible)
    cfg = {
        "quant_method": "foo",
        "bits": 4,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 63.2μs -> 1.98μs (3100% faster)

# --- EDGE TEST CASES ---

def test_missing_quant_method():
    # Edge: Missing quant_method key
    cfg = {
        "bits": 4,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 59.8μs -> 2.01μs (2871% faster)

def test_missing_bits():
    # Edge: Missing bits key
    cfg = {
        "quant_method": "gptq",
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 62.3μs -> 2.34μs (2564% faster)

def test_missing_desc_act():
    # Edge: Missing desc_act key
    cfg = {
        "quant_method": "gptq",
        "bits": 4
    }
    # Should be incompatible since desc_act is missing (None is treated as True)
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 60.0μs -> 2.11μs (2749% faster)

def test_device_capability_none():
    # Edge: Device capability tuple is all None
    cfg = {
        "quant_method": "awq",
        "bits": 4,
        "desc_act": False,
        "_device_capability_tuple": (None, None),
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 60.6μs -> 2.02μs (2899% faster)

def test_device_capability_exact_min():
    # Edge: Device capability exactly at minimum
    cfg = {
        "quant_method": "awq",
        "bits": 4,
        "desc_act": False,
        "_device_capability_tuple": (7, 5),
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 59.2μs -> 1.99μs (2872% faster)

def test_device_capability_above_min():
    # Edge: Device capability above minimum
    cfg = {
        "quant_method": "awq",
        "bits": 4,
        "desc_act": False,
        "_device_capability_tuple": (8, 0),
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 59.5μs -> 2.03μs (2825% faster)

def test_bits_not_int():
    # Edge: bits is not integer
    cfg = {
        "quant_method": "gptq",
        "bits": "4",
        "desc_act": False
    }
    # Should be incompatible because bits is not int
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 58.7μs -> 2.17μs (2608% faster)

def test_desc_act_none():
    # Edge: desc_act is None
    cfg = {
        "quant_method": "gptq",
        "bits": 4,
        "desc_act": None
    }
    # None is treated as True (not False)
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 58.8μs -> 2.03μs (2803% faster)

def test_extra_keys():
    # Edge: Extra unrelated keys should not affect compatibility
    cfg = {
        "quant_method": "gptq",
        "bits": 4,
        "desc_act": False,
        "foo": "bar",
        "baz": 123
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 59.3μs -> 2.01μs (2847% faster)

def test_case_insensitive_quant_method():
    # Edge: quant_method case insensitive
    cfg = {
        "quant_method": "GPTQ",
        "bits": 4,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 58.9μs -> 2.02μs (2810% faster)

def test_bits_float():
    # Edge: bits is float, should not be compatible
    cfg = {
        "quant_method": "gptq",
        "bits": 4.0,
        "desc_act": False
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 60.7μs -> 2.27μs (2573% faster)

def test_awq_with_large_capability_and_wrong_bits():
    # Edge: AWQ, device capability large, but bits wrong
    cfg = {
        "quant_method": "awq",
        "bits": 8,
        "desc_act": False,
        "_device_capability_tuple": (9, 9),  # 99
        "_awq_min_capability": 75
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 60.5μs -> 1.88μs (3125% faster)

def test_gptq_with_extra_keys():
    # Edge: GPTQ, compatible, extra keys
    cfg = {
        "quant_method": "gptq",
        "bits": 8,
        "desc_act": False,
        "irrelevant": "data"
    }
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 58.2μs -> 2.14μs (2622% faster)

# --- LARGE SCALE TEST CASES ---

def test_large_batch_gptq_configs():
    # Large: Test many configs at once for gptq
    for i in range(100):
        cfg = {
            "quant_method": "gptq",
            "bits": 4 if i % 2 == 0 else 8,
            "desc_act": False
        }
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 1.99ms -> 63.6μs (3026% faster)

def test_large_batch_awq_configs():
    # Large: Test many configs at once for awq
    for i in range(100):
        capability = 75 + i % 25  # from 75 to 99
        cfg = {
            "quant_method": "awq",
            "bits": 4,
            "desc_act": False,
            "_device_capability_tuple": (capability // 10, capability % 10),
            "_awq_min_capability": 75
        }
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 1.97ms -> 63.2μs (3012% faster)

def test_large_batch_awq_configs_below_min():
    # Large: Test many configs below min capability for awq
    for i in range(50):
        capability = 50 + i  # from 50 to 99
        bits = 4
        cfg = {
            "quant_method": "awq",
            "bits": bits,
            "desc_act": False,
            "_device_capability_tuple": (capability // 10, capability % 10),
            "_awq_min_capability": 75
        }
        expected = capability >= 75
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 1.02ms -> 32.8μs (3002% faster)

def test_large_batch_mixed_methods():
    # Large: Mix of compatible/incompatible configs
    for i in range(200):
        method = "gptq" if i % 3 == 0 else "awq" if i % 3 == 1 else "foo"
        bits = 4 if i % 2 == 0 else 8
        desc_act = False if i % 5 != 0 else True
        capability = 80
        cfg = {
            "quant_method": method,
            "bits": bits,
            "desc_act": desc_act,
            "_device_capability_tuple": (capability // 10, capability % 10),
            "_awq_min_capability": 75
        }
        if method == "gptq" and not desc_act and bits in [4, 8]:
            expected = True
        elif method == "awq" and bits == 4 and capability >= 75:
            expected = True
        else:
            expected = False
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 3.87ms -> 127μs (2936% faster)

def test_large_batch_missing_keys():
    # Large: Many configs missing keys, all should be incompatible
    for i in range(100):
        cfg = {}
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 1.98ms -> 57.6μs (3339% faster)

def test_large_batch_edge_cases():
    # Large: Many configs with edge values
    for i in range(100):
        bits = None if i % 2 == 0 else 4
        desc_act = None if i % 3 == 0 else False
        quant_method = "gptq" if i % 4 == 0 else "awq"
        capability = 74 if i % 5 == 0 else 75
        cfg = {
            "quant_method": quant_method,
            "bits": bits,
            "desc_act": desc_act,
            "_device_capability_tuple": (capability // 10, capability % 10),
            "_awq_min_capability": 75
        }
        # Only compatible if bits==4, desc_act==False, quant_method in ["gptq", "awq"], and capability>=75 for awq
        if quant_method == "gptq" and bits == 4 and desc_act is False:
            expected = True
        elif quant_method == "awq" and bits == 4 and desc_act is False and capability >= 75:
            expected = True
        else:
            expected = False
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(cfg) # 1.98ms -> 65.0μs (2949% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest

from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# Imports required by the fixtures below (module paths assumed to match the sglang layout):
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.utils import get_device_capability


# --- Pytest fixtures for monkeypatching ---
@pytest.fixture
def patch_device_capability(monkeypatch):
    def _patch(return_value):
        monkeypatch.setattr(__name__ + ".get_device_capability", lambda device_id=0: return_value)
    return _patch

@pytest.fixture
def patch_awq_min_capability(monkeypatch):
    def _patch(value):
        monkeypatch.setattr(AWQConfig, "get_min_capability", classmethod(lambda cls: value))
    return _patch
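
# The two fixtures above are not consumed by the generated tests that follow.
# A hypothetical usage sketch (added for illustration; the capability values are
# made up and the assertion only checks the return type):
def test_fixture_usage_sketch(patch_device_capability, patch_awq_min_capability):
    patch_device_capability((8, 0))  # pretend the device reports compute capability 8.0
    patch_awq_min_capability(75)     # pretend AWQ requires capability >= 75
    config = {"quant_method": "awq", "bits": 4, "desc_act": False}
    assert isinstance(MoeWNA16Config.is_moe_wna16_compatible(config), bool)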

# --- Basic Test Cases ---
def test_gptq_compatible_4bit_desc_act_false():
    # Basic GPTQ, 4 bits, desc_act False should be compatible
    config = {"quant_method": "gptq", "bits": 4, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 57.1μs -> 2.14μs (2574% faster)

def test_gptq_compatible_8bit_desc_act_false():
    # Basic GPTQ, 8 bits, desc_act False should be compatible
    config = {"quant_method": "gptq", "bits": 8, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 54.3μs -> 2.13μs (2443% faster)

def test_gptq_incompatible_desc_act_true():
    # GPTQ, desc_act True should be incompatible
    config = {"quant_method": "gptq", "bits": 4, "desc_act": True}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 60.7μs -> 2.02μs (2900% faster)

def test_gptq_incompatible_bits_not_supported():
    # GPTQ, bits not 4 or 8 should be incompatible
    config = {"quant_method": "gptq", "bits": 5, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 55.3μs -> 2.13μs (2491% faster)


def test_unknown_quant_method():
    # Unknown quant_method should be incompatible
    config = {"quant_method": "foo", "bits": 4, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 94.5μs -> 3.03μs (3019% faster)

# --- Edge Test Cases ---
def test_missing_quant_method_key():
    # quant_method missing: should be incompatible
    config = {"bits": 4, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 71.5μs -> 2.20μs (3151% faster)

def test_missing_bits_key():
    # bits missing: should be incompatible
    config = {"quant_method": "gptq", "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 65.8μs -> 2.52μs (2511% faster)

def test_missing_desc_act_key():
    # desc_act missing: should be incompatible for gptq
    config = {"quant_method": "gptq", "bits": 4}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 64.5μs -> 2.32μs (2683% faster)


def test_gptq_case_insensitive_method():
    # quant_method should be case insensitive
    config = {"quant_method": "GPTQ", "bits": 4, "desc_act": False}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 94.5μs -> 3.32μs (2744% faster)


def test_empty_config():
    # Empty config should be incompatible
    config = {}
    codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 94.0μs -> 3.00μs (3032% faster)

# --- Large Scale Test Cases ---
def test_large_scale_gptq_compatible():
    # Test 1000 configs, all GPTQ, bits 4 or 8, desc_act False
    for i in range(500):
        config = {"quant_method": "gptq", "bits": 4, "desc_act": False}
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 9.50ms -> 298μs (3087% faster)
    for i in range(500):
        config = {"quant_method": "gptq", "bits": 8, "desc_act": False}
        codeflash_output = MoeWNA16Config.is_moe_wna16_compatible(config) # 9.40ms -> 295μs (3077% faster)

To edit these changes, run git checkout codeflash/optimize-MoeWNA16Config.is_moe_wna16_compatible-mhoyvbxs and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 7, 2025 14:44
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 7, 2025