[QEff]: Add OpenAI Oss Models (gpt_oss) #534

Open · wants to merge 11 commits into base: main
78 changes: 63 additions & 15 deletions QEfficient/base/pytorch_transforms.py
@@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

class SplitGateUpWeightsTransform(PytorchTransform):
"""
Split fused Gate+Up weights and copy into the model.
Handles both standard MoE models and GptOss models.

For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]

Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
Also handles bias terms when present.
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
        model_class = model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()

for layer_idx in range(num_layers):
# Determine if this is a GptOss model or standard MoE model
is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")

# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
if is_gpt_oss:
prefix = f"model.layers.{layer_idx}.mlp.experts."
experts = model_tmp.model.layers[layer_idx].mlp.experts
else:
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
experts = model_tmp.model.layers[layer_idx].feed_forward.experts

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# Check if we have bias terms (GptOss case)
has_bias = fused_key + "_bias" in sd
if has_bias:
fused_bias_key = fused_key + "_bias"
gate_bias_key = gate_key + "_bias"
up_bias_key = up_key + "_bias"

# ---- split weights based on model type ----------------------
fused = sd[fused_key] # [E, H, 2I]
E, H, two_I = fused.shape

if is_gpt_oss:
# For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
gate = fused[..., ::2] # [E, H, I] - even indices
up = fused[..., 1::2] # [E, H, I] - odd indices
else:
# For standard MoE, gate/up are concatenated: [gate, up]
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

# Copy weights to model
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# Handle bias if present
if has_bias:
fused_bias = sd[fused_bias_key] # [E, 2I]

if is_gpt_oss:
gate_bias = fused_bias[..., ::2] # [E, I] - even indices
up_bias = fused_bias[..., 1::2] # [E, I] - odd indices
else:
ffn_dim = fused_bias.shape[-1] // 2
gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)

experts.gate_proj_bias.data.copy_(gate_bias)
experts.up_proj_bias.data.copy_(up_bias)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if has_bias:
sd[gate_bias_key] = gate_bias
sd[up_bias_key] = up_bias

# Delete fused keys
if delete_fused_key:
del sd[fused_key]
if has_bias:
del sd[fused_bias_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp

return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
# Model classes whose fused gate_up_proj weights should be split
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}