
Conversation


@codeflash-ai codeflash-ai bot commented Nov 7, 2025

📄 28% (0.28x) speedup for MoeWNA16Config.from_config in python/sglang/srt/layers/quantization/moe_wna16.py

⏱️ Runtime : 385 microseconds → 302 microseconds (best of 39 runs)

📝 Explanation and details

Explanation of Optimizations

  1. base_config.py

    • Changed the default initialization of packed_modules_mapping from dict() to {}. This is a micro-optimization: the {} literal is faster than calling dict() to create an empty dict.
    • Optimized the loop in get_from_keys by using the generator form of the pattern: return next((config[k] for k in keys if k in config), ...) with a default sentinel that triggers the ValueError. This reduces average-case lookup time when the key is found early (see the first sketch after this list).
  2. moe_wna16.py

    • Avoided recalculating the device capability tuple when it is not needed.
    • Pulled the computation of modules_to_not_convert out of the conditional, since the logic for setting it to [] or a default value was repeated.
    • Used a direct truthy check on modules_to_not_convert, removing a small overhead and slightly improving code clarity.
    • Considered dictionary unpacking (still compatible with Python 3.10) for method calls that take a dictionary, but since that pattern does not apply here, the code was left unchanged.
    • Changed the modules_to_not_convert assignment: instead of an if/else None check, use or to fall back to an empty list, saving a branch at runtime (see the second sketch after this list).
    • No other runtime optimizations were applicable while preserving the behavioral and code style constraints.
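
A minimal sketch of the two base_config.py changes, assuming simplified signatures; the class name, docstring, and error message below are illustrative, not the actual sglang code:

```python
from typing import Any, Dict, List

_MISSING = object()  # sentinel that cannot collide with any stored value


def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
    """Return the value for the first key in `keys` that is present in `config`."""
    value = next((config[k] for k in keys if k in config), _MISSING)
    if value is _MISSING:
        raise ValueError(f"Cannot find any of {keys} in the given config")
    return value


class BaseQuantConfig:
    def __init__(self) -> None:
        # A {} literal avoids the global name lookup and call overhead of dict()
        self.packed_modules_mapping = {}


# Example: the first matching key wins, so later keys are never touched
assert get_from_keys({"bits": 4, "wbits": 8}, ["bits", "wbits"]) == 4
```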

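And a hedged before/after sketch of the modules_to_not_convert fallback in moe_wna16.py; quant_config here is a stand-in dict, and the surrounding from_config logic is omitted:

```python
# Stand-in config dict for illustration only; not the real from_config argument.
quant_config = {"quant_method": "awq", "modules_to_not_convert": None}

# Before: an explicit None check costs an extra branch
modules_to_not_convert = quant_config.get("modules_to_not_convert")
if modules_to_not_convert is None:
    modules_to_not_convert = []

# After: `or` folds both a missing key and an explicit None into the [] default
modules_to_not_convert = quant_config.get("modules_to_not_convert") or []
assert modules_to_not_convert == []
```
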
Correctness verification report:

| Test                          | Status        |
|-------------------------------|---------------|
| ⚙️ Existing Unit Tests        | 🔘 None Found |
| 🌀 Generated Regression Tests | 414 Passed    |
| ⏪ Replay Tests               | 🔘 None Found |
| 🔎 Concolic Coverage Tests    | 🔘 None Found |
| 📊 Tests Coverage             | 100.0%        |
🌀 Generated Regression Tests and Runtime
import sys
import types
from typing import Any, Dict, List, Optional

# imports
import pytest
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# --- Minimal stubs/mocks for dependencies (as per instructions, no pytest.mock etc.) ---

# Simulate CUDA always available for test purposes
class torch_cuda_stub:
    @staticmethod
    def is_available():
        return True

    @staticmethod
    def get_device_capability(device_id=0):
        # Return a "modern" capability (8, 0) for most tests
        return (8, 0)

class torch_stub:
    cuda = torch_cuda_stub()
    # For AWQ, we only need cuda for now

# --- Implementing stubs for required classes and functions ---

# sglang.srt.utils.get_device_capability
def get_device_capability(device_id: int = 0):
    return torch_stub.cuda.get_device_capability(device_id)

# sglang.srt.layers.quantization.awq.AWQConfig
class AWQConfig:
    @classmethod
    def get_min_capability(cls):
        # AWQ requires at least capability 7.5 (Turing)
        return 75

# sglang.srt.layers.quantization.gptq.GPTQMarlinConfig
class GPTQMarlinConfig:
    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config):
        # For testing, let's say if 'marlin_compatible' is set, return its value
        # Otherwise, True if bits==4 or bits==8, group_size > 0, sym in [True, False], desc_act in [True, False]
        if "marlin_compatible" in quant_config:
            return quant_config["marlin_compatible"]
        bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")
        if bits in (4, 8) and group_size is not None and sym in (True, False) and desc_act in (True, False):
            return True
        return False
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# --- UNIT TESTS ---

# ------------------------------
# 1. BASIC TEST CASES
# ------------------------------

def test_gptq_basic_config():
    # Basic GPTQ config, sym=False, should set has_zp=True, modules_to_not_convert=[]
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 128,
        "sym": False,
        "desc_act": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 10.2μs -> 7.60μs (33.9% faster)

def test_gptq_basic_config_sym_true():
    # GPTQ config, sym=True, should set has_zp=False
    config = {
        "quant_method": "gptq",
        "bits": 8,
        "group_size": 64,
        "sym": True,
        "desc_act": False,
        "marlin_compatible": False,  # override use_marlin
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.40μs -> 5.14μs (43.9% faster)

def test_gptq_with_lm_head():
    # GPTQ config with lm_head=True
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": True,
        "desc_act": True,
        "lm_head": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 5.61μs -> 4.97μs (12.7% faster)

def test_awq_basic_config():
    # AWQ config, zero_point=True, modules_to_not_convert provided
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 64,
        "zero_point": True,
        "modules_to_not_convert": ["foo", "bar"],
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_awq_basic_config_no_modules_to_not_convert():
    # AWQ config, modules_to_not_convert not provided (should default to [])
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": False,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_awq_with_lm_head():
    # AWQ config with lm_head=True
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
        "lm_head": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

# ------------------------------
# 2. EDGE TEST CASES
# ------------------------------

def test_missing_quant_method():
    # Should raise ValueError if quant_method is missing
    config = {
        "bits": 4,
        "group_size": 32,
        "sym": True,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 3.16μs -> 2.97μs (6.43% faster)

def test_missing_bits():
    # Should raise ValueError if bits is missing
    config = {
        "quant_method": "gptq",
        "group_size": 32,
        "sym": True,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 2.65μs -> 2.57μs (3.15% faster)

def test_missing_group_size():
    # Should raise ValueError if group_size is missing
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "sym": True,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 2.58μs -> 2.51μs (2.82% faster)

def test_gptq_missing_sym():
    # Should raise ValueError if sym is missing for gptq
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 4.63μs -> 3.30μs (40.4% faster)

def test_awq_missing_zero_point():
    # Should raise ValueError if zero_point is missing for awq
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 3.93μs -> 2.90μs (35.5% faster)

def test_invalid_quant_method():
    # Should raise ValueError for invalid quant_method
    config = {
        "quant_method": "foobar",
        "bits": 4,
        "group_size": 32,
        "sym": True,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 3.45μs -> 2.02μs (70.4% faster)

def test_awq_device_capability_too_low(monkeypatch):
    # Should raise ValueError if device capability is too low for AWQ
    def fake_get_device_capability(device_id=0):
        return (5, 0)  # 50 < 75
    monkeypatch.setattr(__name__ + ".get_device_capability", fake_get_device_capability)
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
    }
    with pytest.raises(ValueError) as excinfo:
        MoeWNA16Config.from_config(config)

def test_awq_device_capability_none(monkeypatch):
    # Should raise ValueError if get_device_capability returns None
    def fake_get_device_capability(device_id=0):
        return (None, None)
    monkeypatch.setattr(__name__ + ".get_device_capability", fake_get_device_capability)
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config)

def test_gptq_marlin_compatible_false():
    # use_marlin should be False if GPTQMarlinConfig.is_gptq_marlin_compatible returns False
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": True,
        "desc_act": True,
        "marlin_compatible": False,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 9.71μs -> 7.75μs (25.3% faster)

def test_gptq_marlin_compatible_true():
    # use_marlin should be True if GPTQMarlinConfig.is_gptq_marlin_compatible returns True
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": True,
        "desc_act": True,
        "marlin_compatible": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.33μs -> 5.20μs (41.0% faster)

def test_awq_modules_to_not_convert_none():
    # AWQ: modules_to_not_convert=None should result in []
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
        "modules_to_not_convert": None,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_awq_modules_to_not_convert_empty_list():
    # AWQ: modules_to_not_convert=[]
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
        "modules_to_not_convert": [],
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_gptq_modules_to_not_convert_always_empty():
    # GPTQ: modules_to_not_convert always []
    config = {
        "quant_method": "gptq",
        "bits": 8,
        "group_size": 64,
        "sym": True,
        "desc_act": False,
        "modules_to_not_convert": ["should", "be", "ignored"],
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 10.0μs -> 7.56μs (32.4% faster)

# ------------------------------
# 3. LARGE SCALE TEST CASES
# ------------------------------

def test_large_group_size():
    # Test with large group_size (but < 1000)
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 999,
        "sym": False,
        "desc_act": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.55μs -> 5.30μs (42.4% faster)

def test_large_modules_to_not_convert():
    # AWQ: modules_to_not_convert with 999 elements
    modules = [f"mod_{i}" for i in range(999)]
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        "modules_to_not_convert": modules,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_large_config_dict():
    # Config dict with many irrelevant keys
    config = {
        "quant_method": "gptq",
        "bits": 8,
        "group_size": 128,
        "sym": True,
        "desc_act": False,
        "lm_head": False,
    }
    # Add 900 irrelevant keys
    for i in range(900):
        config[f"irrelevant_{i}"] = i
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.97μs -> 7.55μs (5.54% faster)

def test_many_calls_performance():
    # Call from_config 100 times, should not degrade or leak
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 256,
        "sym": True,
        "desc_act": True,
    }
    results = []
    for i in range(100):
        codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 216μs -> 166μs (29.7% faster)
        results.append(obj)

def test_awq_large_modules_to_not_convert_and_lm_head():
    # AWQ with large modules_to_not_convert and lm_head
    modules = [f"layer_{i}" for i in range(500)]
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": False,
        "modules_to_not_convert": modules,
        "lm_head": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import Any, Dict, List, Optional

# imports
import pytest
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# --- Minimal stubs and mocks for dependencies (since we can't import sglang) ---

# Simulate CUDA/Device capability for AWQ
class AWQConfig:
    @classmethod
    def get_min_capability(cls):
        return 75

# Simulate Marlin compatibility for GPTQ
class GPTQMarlinConfig:
    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config):
        # For test, return True if "marlin_compatible" in config and it's True
        return quant_config.get("marlin_compatible", False)

# Simulate device capability getter
def get_device_capability():
    # For test, can be monkeypatched
    return (8, 0)  # 80, which is >= 75 (AWQ min capability)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

# --- Test suite for MoeWNA16Config.from_config ---

# --- BASIC TEST CASES ---

def test_gptq_basic_config():
    # Basic GPTQ config, sym=False -> has_zp=True
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": False,
        "lm_head": True,
        "marlin_compatible": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 8.04μs -> 7.69μs (4.64% faster)

def test_gptq_basic_config_sym_true():
    # sym=True -> has_zp=False
    config = {
        "quant_method": "gptq",
        "bits": 8,
        "group_size": 128,
        "sym": True,
        "lm_head": False,
        "marlin_compatible": False,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 5.84μs -> 5.21μs (12.2% faster)

def test_awq_basic_config():
    # Basic AWQ config, zero_point=True
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 64,
        "zero_point": True,
        "modules_to_not_convert": ["layer1", "layer2"],
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_awq_basic_config_modules_to_not_convert_none():
    # AWQ config, modules_to_not_convert missing (should default to [])
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": False,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_lm_head_default_false():
    # lm_head missing, should default to False
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": True,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 10.2μs -> 7.67μs (33.0% faster)

# --- EDGE TEST CASES ---

def test_missing_quant_method_raises():
    # quant_method missing
    config = {
        "bits": 4,
        "group_size": 32,
        "sym": False,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 2.44μs -> 2.82μs (13.7% slower)

def test_missing_bits_raises():
    config = {
        "quant_method": "gptq",
        "group_size": 32,
        "sym": False,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 2.69μs -> 2.57μs (4.67% faster)

def test_missing_group_size_raises():
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "sym": False,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 2.78μs -> 2.54μs (9.36% faster)

def test_gptq_missing_sym_raises():
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 4.48μs -> 3.09μs (44.8% faster)

def test_awq_missing_zero_point_raises():
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
    }
    with pytest.raises(ValueError):
        MoeWNA16Config.from_config(config) # 3.92μs -> 3.10μs (26.7% faster)

def test_unsupported_quant_method_raises():
    config = {
        "quant_method": "foo",
        "bits": 4,
        "group_size": 32,
        "sym": False,
    }
    with pytest.raises(ValueError, match="moe_wna16 only support gptq and awq"):
        MoeWNA16Config.from_config(config) # 3.44μs -> 2.17μs (58.6% faster)

def test_awq_device_capability_too_low(monkeypatch):
    # Simulate device capability below AWQ min (e.g. 70 < 75)
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
    }
    def fake_get_device_capability():
        return (7, 0)  # 70
    monkeypatch.setattr(__name__ + ".get_device_capability", fake_get_device_capability)
    with pytest.raises(ValueError, match="not supported.*Minimum capability: 75"):
        MoeWNA16Config.from_config(config)

def test_awq_device_capability_none(monkeypatch):
    # Simulate device capability returns None (should be treated as -1)
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
    }
    def fake_get_device_capability():
        return (None, None)
    monkeypatch.setattr(__name__ + ".get_device_capability", fake_get_device_capability)
    with pytest.raises(ValueError, match="not supported.*Minimum capability: 75"):
        MoeWNA16Config.from_config(config)

def test_gptq_marlin_compatibility(monkeypatch):
    # Simulate GPTQMarlinConfig.is_gptq_marlin_compatible returning True/False
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": False,
        "marlin_compatible": True,
    }
    # Should set use_marlin True
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 9.95μs -> 7.54μs (32.0% faster)

    config["marlin_compatible"] = False
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 3.66μs -> 2.76μs (32.5% faster)

def test_awq_modules_to_not_convert_explicit_none():
    # AWQ config, modules_to_not_convert explicitly None
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": True,
        "modules_to_not_convert": None,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_gptq_modules_to_not_convert_always_empty():
    # For gptq, modules_to_not_convert must always be []
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": False,
        "modules_to_not_convert": ["foo"],  # Should be ignored
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 9.76μs -> 7.66μs (27.4% faster)

# --- LARGE SCALE TEST CASES ---

def test_large_group_size():
    # Test with large group_size
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 999,
        "sym": False,
        "marlin_compatible": False,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.50μs -> 5.29μs (41.7% faster)

def test_large_modules_to_not_convert():
    # Large list for modules_to_not_convert
    modules = [f"layer{i}" for i in range(500)]
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 64,
        "zero_point": True,
        "modules_to_not_convert": modules,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_many_configs_in_succession():
    # Create and check 100 configs in a loop (scalability, determinism)
    for i in range(100):
        config = {
            "quant_method": "gptq" if i % 2 == 0 else "awq",
            "bits": 4,
            "group_size": i + 1,
            "sym": False if i % 2 == 0 else None,
            "zero_point": True if i % 2 == 1 else None,
            "marlin_compatible": i % 3 == 0,
        }
        if config["quant_method"] == "gptq":
            config["sym"] = False
            codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output
        else:
            config["zero_point"] = True
            codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output

def test_large_config_dict():
    # Config with many irrelevant keys (should ignore extras)
    config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 32,
        "sym": False,
        "lm_head": True,
    }
    # Add 900 irrelevant keys
    for i in range(900):
        config[f"irrelevant_{i}"] = i
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output # 7.83μs -> 7.40μs (5.82% faster)

def test_awq_with_long_module_names():
    # modules_to_not_convert with long string names
    modules = [f"layer_{'x'*100}_{i}" for i in range(50)]
    config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 32,
        "zero_point": False,
        "modules_to_not_convert": modules,
    }
    codeflash_output = MoeWNA16Config.from_config(config); obj = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-MoeWNA16Config.from_config-mhoy795d` and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 7, 2025 14:25
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Nov 7, 2025