From 485308ade3b66a692350229ff06eb94e73c35d5c Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 6 Aug 2025 19:00:00 +0000 Subject: [PATCH 01/10] [QEff]: Add gpt_oss Signed-off-by: vbaddi --- QEfficient/base/pytorch_transforms.py | 79 ++- .../transformers/models/gpt_oss/__init__.py | 1 + .../models/gpt_oss/modeling_gpt_oss.py | 535 ++++++++++++++++++ .../transformers/models/modeling_auto.py | 3 +- .../transformers/models/pytorch_transforms.py | 25 + examples/gpt_oss.py | 54 ++ pyproject.toml | 2 +- 7 files changed, 694 insertions(+), 5 deletions(-) create mode 100644 QEfficient/transformers/models/gpt_oss/__init__.py create mode 100644 QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py create mode 100644 examples/gpt_oss.py diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index a20fc4cb3..e6ec713f8 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -144,7 +144,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: sd = model_tmp.state_dict() for layer_idx in range(num_layers): # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." + prefix = f"model.layers.{layer_idx}.mlp.experts." fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" @@ -156,7 +156,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: ffn_dim = two_I // 2 gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].feed_forward.experts + experts = model_tmp.model.layers[layer_idx].mlp.experts experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) @@ -177,4 +177,77 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} +class SplitGateUpWeightsTransformGPTOSS(PytorchTransform): + """ + split fused Gate+Up weights and copy into the model + + For every transformer layer inside `model`: + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] + """ + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + transformed = False + model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ + + if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: + return model, transformed + + model_tmp = model.language_model if hasattr(model, "language_model") else model + num_layers = len(model_tmp.model.layers) + delete_fused_key = True + sd = model_tmp.state_dict() + + for layer_idx in range(num_layers): + # ---- build the textual prefix once per layer ---------- + prefix = f"model.layers.{layer_idx}.mlp.experts." + fused_key = prefix + "gate_up_proj" + fused_bias_key = prefix + "gate_up_proj_bias" + gate_key = prefix + "gate_proj" + up_key = prefix + "up_proj" + gate_bias_key = prefix + "gate_proj_bias" + up_bias_key = prefix + "up_proj_bias" + + # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- + fused = sd[fused_key] # [E, H, 2I] + fused_bias = sd[fused_bias_key] # [E, 2I] + E, H, two_I = fused.shape + # ffn_dim = two_I // 2 + + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] 
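+            # Hedged illustration (hypothetical toy tensor, not taken from any checkpoint):
+            #   fused = torch.arange(4).view(1, 1, 4)   # [[[0, 1, 2, 3]]]
+            #   fused[..., ::2]   -> [[[0, 2]]]          # gate columns (even indices)
+            #   fused[..., 1::2]  -> [[[1, 3]]]          # up columns   (odd indices)
+            # This interleaved layout differs from the Llama4-style checkpoints handled
+            # above, where gate/up are concatenated and split via fused.split(ffn_dim, dim=-1).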
+ gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + + experts = model_tmp.model.layers[layer_idx].mlp.experts + experts.gate_proj.data.copy_(gate) + experts.up_proj.data.copy_(up) + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) + + # ---- update the state-dict so load_state_dict sees the right keys + sd[gate_key] = gate + sd[up_key] = up + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + if delete_fused_key: + del sd[fused_key] + del sd[fused_bias_key] + + logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") + transformed = True + + if hasattr(model, "language_model"): + model.language_model = model_tmp + else: + model = model_tmp + + return model, transformed + + +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py new file mode 100644 index 000000000..792d60054 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -0,0 +1 @@ +# diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..60c48b3ad --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,535 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + 
masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + # Skip if no tokens routed to this expert + # if (routing_weight > 0).sum() == 0: + # continue + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffQwen2RotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. 
The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
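+    # Note on the sink column: `sinks` appends one learned logit per head, so the
+    # softmax below normalises over [real keys, sink]. The sink column is dropped
+    # afterwards, leaving the per-token attention weights summing to < 1.
+    # Subtracting the row-wise max only shifts the logits for numerical stability;
+    # softmax is shift-invariant, so the resulting probabilities are unchanged.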
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + 
use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + # import ipdb; ipdb.set_trace() + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + # past_key_values = QEffHybridCache.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + # causal_mask = _create_causal_mask( + # position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + # ) + + # # It may already have been prepared by e.g. 
`generate` + # if not isinstance(causal_mask_mapping := attention_mask, dict): + # mask_kwargs = { + # "config": self.config, + # "input_embeds": inputs_embeds, + # "attention_mask": attention_mask, + # "cache_position": cache_position, + # "past_key_values": past_key_values, + # } + # causal_mask_mapping = { + # "full_attention": create_causal_mask(**mask_kwargs), + # "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + # } + + hidden_states = inputs_embeds + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + # # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + # logits = self.lm_head(hidden_states[:, slice_indices, :]) + + # loss = None + # if labels is not None: + # loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + # aux_loss = None + # if output_router_logits: + # aux_loss = load_balancing_loss_func( + # outputs.router_logits, + # self.num_experts, + # self.num_experts_per_tok, + # attention_mask, + # ) + # if labels is not None: + # loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2f3ee3dc0..8690ea03a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -27,7 +27,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform -from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform +from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform, SplitGateUpWeightsTransformGPTOSS from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( CloudAI100ExecInfoNew, @@ -1381,6 +1381,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, + SplitGateUpWeightsTransformGPTOSS, KVCacheExternalModuleMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..20839590f 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py 
@@ -51,6 +51,15 @@ GPTBigCodeForCausalLM, GPTBigCodeModel, ) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRMSNorm, +) from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -199,6 +208,14 @@ QEffGPTBigCodeForCausalLM, QEffGPTBigCodeModel, ) +from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffGptOssDecoderLayer, + QEffGptOssExperts, + QEffGptOssForCausalLM, + QEffGptOssMLP, + QEffGptOssModel, +) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, QEffGPTJBlock, @@ -338,6 +355,7 @@ class CustomOpsTransform(ModuleMappingTransform): MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, + GptOssRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, } @@ -399,6 +417,13 @@ class KVCacheTransform(ModuleMappingTransform): Gemma3TextModel: QEffGemma3TextModel, Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # GPT_OSS + GptOssAttention: QEffGptOssAttention, + GptOssDecoderLayer: QEffGptOssDecoderLayer, + GptOssModel: QEffGptOssModel, + GptOssForCausalLM: QEffGptOssForCausalLM, + GptOssMLP: QEffGptOssMLP, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py new file mode 100644 index 000000000..d33500f92 --- /dev/null +++ b/examples/gpt_oss.py @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## BEFORE RUNNING PLS, RUN THE CONVERT SCRIPT TO CONVERT THE SAFETENSORS FROM FP4 to BF16 +## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW +import torch +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + +torch.manual_seed(42) +model_id = "CONVERTED_WEIGHTS" # See Comments above to convert saftensors to BF16 +config = AutoConfig.from_pretrained(model_id) + +model = GptOssForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, attn_implementation="eager", config=config +) +model.eval() + +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) + +api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, Constants.PROMPT_LEN, Constants.CTX_LEN) + +qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=32, + ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + streamer=streamer, + prompts="Who is your creator? and What all you are allowed to do?", + device_ids=[0, 1, 2, 3], +) diff --git a/pyproject.toml b/pyproject.toml index 479736c22..00b2bb7ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.11" dependencies = [ - "transformers==4.51.3", + "transformers==4.55.0", "huggingface-hub==0.30.0", "hf_transfer==0.1.9", "peft==0.13.2", From 59e2115f3818eb1d64d1b204f3d010a5daad1019 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 6 Aug 2025 19:02:41 +0000 Subject: [PATCH 02/10] nit: update transforms Signed-off-by: vbaddi --- QEfficient/base/pytorch_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index e6ec713f8..1c1acfabd 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -144,7 +144,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: sd = model_tmp.state_dict() for layer_idx in range(num_layers): # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.mlp.experts." + prefix = f"model.layers.{layer_idx}.feed_forward.experts." 
fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" @@ -156,7 +156,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: ffn_dim = two_I // 2 gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].mlp.experts + experts = model_tmp.model.layers[layer_idx].feed_forward.experts experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) From a6c281269d158e8b805982334c58ba8ed487114a Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 6 Aug 2025 19:05:18 +0000 Subject: [PATCH 03/10] nit: add header to __init__ Signed-off-by: vbaddi --- QEfficient/transformers/models/gpt_oss/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py index 792d60054..75daf1953 100644 --- a/QEfficient/transformers/models/gpt_oss/__init__.py +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -1 +1,6 @@ +# ----------------------------------------------------------------------------- # +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- From 8e5783ec70fc4e1185b29b45d1bdfca4299babdc Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 Aug 2025 15:23:21 +0530 Subject: [PATCH 04/10] apirunner change Signed-off-by: Onkar Chougule --- QEfficient/transformers/cache_utils.py | 100 ++++++++++++++++++ QEfficient/transformers/modeling_utils.py | 1 + .../models/gpt_oss/modeling_gpt_oss.py | 79 ++++++-------- .../transformers/models/modeling_auto.py | 17 ++- QEfficient/utils/generate_inputs.py | 6 +- 5 files changed, 155 insertions(+), 48 deletions(-) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..dc87b2be7 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -488,3 +488,103 @@ def update( ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[1][0].shape[2]) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + return len(self.key_cache) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Original Gather + ctx_len = self.key_cache[layer_idx].shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + all_indices = torch.arange(sliding_window) + kv_position_ids.max() + 1 + rolling_indices = torch.where(all_indices > (sliding_window - 1), all_indices % sliding_window, all_indices) + if is_sliding_layer: + final_indices = torch.where(position_ids.max() >= sliding_window, rolling_indices, ctx_indices) + else: + final_indices = ctx_indices + k_out = CtxGatherFunc.apply(k_out, final_indices) + v_out = CtxGatherFunc.apply(v_out, final_indices) + ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + if is_sliding_layer: + v_out = torch.where(position_ids.max() >= sliding_window, v_out, ctx_v_out) + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 72b7acd98..fa6bfd0f4 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -183,6 +183,7 @@ ] ) +# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. 
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # Define a transformers layers to QEff layers dictionary diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 60c48b3ad..5fa12d106 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -28,7 +28,7 @@ from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs -from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -232,6 +232,7 @@ def forward( past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -241,10 +242,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + # kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -255,9 +256,16 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": self.sliding_window, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -289,6 +297,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states @@ -303,6 +312,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sliding_mask=sliding_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -351,11 +361,9 @@ def forward( raise ValueError("You must specify exactly one of input_ids or inputs_embeds") return_legacy_cache = False - # import ipdb; ipdb.set_trace() if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - # past_key_values = QEffHybridCache.from_legacy_cache(self.config, past_key_values) + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -368,25 +376,13 @@ def forward( if position_ids is None: 
position_ids = cache_position.unsqueeze(0) - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) - # causal_mask = _create_causal_mask( - # position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window - # ) - - # # It may already have been prepared by e.g. `generate` - # if not isinstance(causal_mask_mapping := attention_mask, dict): - # mask_kwargs = { - # "config": self.config, - # "input_embeds": inputs_embeds, - # "attention_mask": attention_mask, - # "cache_position": cache_position, - # "past_key_values": past_key_values, - # } - # causal_mask_mapping = { - # "full_attention": create_causal_mask(**mask_kwargs), - # "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), - # } + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=self.config.sliding_window, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds # position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -408,6 +404,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, + sliding_mask=sliding_mask, **kwargs, ) hidden_states = layer_outputs[0] @@ -505,25 +502,6 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - # # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - # logits = self.lm_head(hidden_states[:, slice_indices, :]) - - # loss = None - # if labels is not None: - # loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - # aux_loss = None - # if output_router_logits: - # aux_loss = load_balancing_loss_func( - # outputs.router_logits, - # self.num_experts, - # self.num_experts_per_tok, - # attention_mask, - # ) - # if labels is not None: - # loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - return MoeCausalLMOutputWithPast( loss=None, aux_loss=None, @@ -533,3 +511,14 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) + + def get_pkv_dynamic_axes( + self, + ): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + elif layer_type == "full_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8690ea03a..1c4f674d4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1592,10 +1592,20 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names.append(f"past_{kv}.{i}_RetainedState") else: + # HACK: create common function for this including above if condition code + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * 
self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + for i in range(self.num_layers): for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: @@ -1842,6 +1852,11 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + # HACK for now + if self.model.config.model_type == "gpt_oss": + for spec in specializations: + spec.update({"sliding_window": 128}) + qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 361be3080..2f353fe52 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -91,9 +91,11 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(1).view(-1, 1) past_key_values = [] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + self.padding_shape[-1] for i in range(self.n_layer): - past_key = torch.zeros((self.padding_shape), dtype=torch.float32) - past_value = torch.zeros((self.padding_shape), dtype=torch.float32) + pad_shape = sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) From 5c3c971449e7f70f5403759b4e035cec30cd19f1 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 Aug 2025 19:24:24 +0530 Subject: [PATCH 05/10] added test along with simplified Hybridcache Signed-off-by: Onkar Chougule --- QEfficient/transformers/cache_utils.py | 18 ++++-- .../models/gpt_oss/modeling_gpt_oss.py | 6 +- .../models/mistral/modeling_mistral.py | 5 -- .../transformers/models/modeling_auto.py | 2 +- QEfficient/utils/generate_inputs.py | 6 +- pyproject.toml | 6 +- tests/test_gpt.py | 61 +++++++++++++++++++ 7 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 tests/test_gpt.py diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index dc87b2be7..de51f8114 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -494,9 +494,10 @@ def update( # We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and # ours are made to work with AIC class QEffHybridCacheForGPTOSS: - def __init__(self, config, batch_size, max_cache_len): + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): self.max_cache_len = max_cache_len self.batch_size = batch_size + self.sliding_window_len = sliding_window_len self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] @@ -506,7 +507,12 @@ def from_legacy_cache( ) -> "HybridCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. 
Used for backward compatibility.""" - cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[1][0].shape[2]) + cache = cls( + config, + batch_size=past_key_values[0][0].shape[0], + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) if past_key_values is not None: for layer_idx in range(len(past_key_values)): key_states, value_states = past_key_values[layer_idx] @@ -554,6 +560,7 @@ def update( position_ids = cache_kwargs.get("position_ids") is_sliding_layer = cache_kwargs.get("is_sliding") sliding_window = cache_kwargs.get("sliding_window") + if is_sliding_layer: kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) else: @@ -576,12 +583,15 @@ def update( invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - all_indices = torch.arange(sliding_window) + kv_position_ids.max() + 1 - rolling_indices = torch.where(all_indices > (sliding_window - 1), all_indices % sliding_window, all_indices) if is_sliding_layer: + all_indices = torch.arange(sliding_window) + kv_position_ids.max() + 1 + rolling_indices = torch.where( + all_indices > (sliding_window - 1), all_indices % sliding_window, all_indices + )[None, None, ...] final_indices = torch.where(position_ids.max() >= sliding_window, rolling_indices, ctx_indices) else: final_indices = ctx_indices + k_out = CtxGatherFunc.apply(k_out, final_indices) v_out = CtxGatherFunc.apply(v_out, final_indices) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5fa12d106..cb25246f0 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -257,7 +257,7 @@ def forward( "position_ids": position_ids, "config": self.config, "is_sliding": self.sliding_window is not None, - "sliding_window": self.sliding_window, + "sliding_window": past_key_value.sliding_window_len, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -380,8 +380,8 @@ def forward( causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) sliding_mask = _create_causal_mask( position_ids=position_ids, - target_length=self.config.sliding_window, - sliding_window=self.config.sliding_window, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, ) hidden_states = inputs_embeds diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 60b1c929d..7a084abcb 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,7 +24,6 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - logger, repeat_kv, rotate_half, ) @@ -298,10 +297,6 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and not self.training: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 1c4f674d4..ff92159f5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1380,7 +1380,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): FP8DeQuantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, - SplitGateUpWeightsTransform, + # SplitGateUpWeightsTransform, SplitGateUpWeightsTransformGPTOSS, KVCacheExternalModuleMapperTransform, ] diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 2f353fe52..fd81e6306 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -91,9 +91,11 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(1).view(-1, 1) past_key_values = [] - sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + self.padding_shape[-1] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): - pad_shape = sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape + pad_shape = ( + sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape + ) past_key = torch.zeros((pad_shape), dtype=torch.float32) past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) diff --git a/pyproject.toml b/pyproject.toml index 00b2bb7ef..5dfb00e63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,10 @@ classifiers = [ requires-python = ">=3.8,<3.11" dependencies = [ "transformers==4.55.0", - "huggingface-hub==0.30.0", + "huggingface-hub", "hf_transfer==0.1.9", - "peft==0.13.2", - "datasets==2.20.0", + "peft", + "datasets", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", diff --git a/tests/test_gpt.py b/tests/test_gpt.py new file mode 100644 index 000000000..27b423b63 --- /dev/null +++ b/tests/test_gpt.py @@ -0,0 +1,61 @@ +import torch +from transformers import AutoConfig, AutoModelForCausalLM, GptOssForCausalLM, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + +Constants.INPUT_STR=["Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. 
Press or filter the coffee, then serve"] + +torch.manual_seed(42) +model_id = "openai/gpt-oss-20b" +config = AutoConfig.from_pretrained(model_id) +config.num_hidden_layers=2 + +# Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues +if hasattr(config, "quantization_config"): + delattr(config, "quantization_config") + +model = GptOssForCausalLM.from_pretrained( + "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", torch_dtype=torch.float32, attn_implementation="eager", config=config +) +model.eval() +model.generation_config.sample=False +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) + +api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, 97, 256) +pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) + + +qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) +# pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + +onnx_model_path = qeff_model.export() + + +qpc_path = qeff_model.compile( + prefill_seq_len=128, + ctx_len=256, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + streamer=streamer, + prompts=Constants.INPUT_STR[0], + device_ids=[0], +) + +import ipdb; ipdb.set_trace() +print(pytorch_hf_tokens) +print(exec_info) From ce53d3cd9370dd2406e6b3cf37e54320f82b0b7e Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 Aug 2025 19:26:57 +0530 Subject: [PATCH 06/10] added test assert Signed-off-by: Onkar Chougule --- tests/test_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 27b423b63..92c17c353 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -56,6 +56,6 @@ device_ids=[0], ) -import ipdb; ipdb.set_trace() print(pytorch_hf_tokens) print(exec_info) +assert (exec_info.generated_ids[0][0,:159] == pytorch_hf_tokens).all() From e0bd90fe796bcf16fc7207f56ae5705be9ae75a9 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 7 Aug 2025 14:34:08 +0000 Subject: [PATCH 07/10] nit: update modeling and make transform uniform Signed-off-by: vbaddi --- QEfficient/base/pytorch_transforms.py | 139 +++++------ QEfficient/transformers/cache_utils.py | 218 +++++++++++------- .../models/gpt_oss/modeling_gpt_oss.py | 109 +++++---- .../models/mistral/modeling_mistral.py | 5 - .../transformers/models/modeling_auto.py | 3 +- pyproject.toml | 6 +- 6 files changed, 265 insertions(+), 215 deletions(-) diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index 1c1acfabd..e503a057f 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -120,79 +120,23 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SplitGateUpWeightsTransform(PytorchTransform): """ - split fused Gate+Up weights and copy into the model + Split fused Gate+Up weights and copy into the model. + Handles both standard MoE models and GptOss models. 
For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] - """ - - @classmethod - def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: - transformed = False - model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: - return model, transformed - - model_tmp = model.language_model if hasattr(model, "language_model") else model - - num_layers = len(model_tmp.model.layers) - delete_fused_key = True - sd = model_tmp.state_dict() - for layer_idx in range(num_layers): - # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." - - fused_key = prefix + "gate_up_proj" - gate_key = prefix + "gate_proj" - up_key = prefix + "up_proj" + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- - fused = sd[fused_key] # [E, H, 2I] (no .weight here) - E, H, two_I = fused.shape - ffn_dim = two_I // 2 - gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - - experts = model_tmp.model.layers[layer_idx].feed_forward.experts - experts.gate_proj.data.copy_(gate) - experts.up_proj.data.copy_(up) - - # ---- update the state-dict so load_state_dict sees the right keys - sd[gate_key] = gate - sd[up_key] = up - - if delete_fused_key: - del sd[fused_key] - - logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") - transformed = True - - if hasattr(model, "language_model"): - model.language_model = model_tmp - else: - model = model_tmp - return model, transformed - - -class SplitGateUpWeightsTransformGPTOSS(PytorchTransform): - """ - split fused Gate+Up weights and copy into the model - - For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] + Handles both interleaved weights (GptOss) and concatenated weights (standard MoE). + Also handles bias terms when present. """ @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: return model, transformed @@ -202,42 +146,72 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: sd = model_tmp.state_dict() for layer_idx in range(num_layers): + # Determine if this is a GptOss model or standard MoE model + is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp") + # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.mlp.experts." + if is_gpt_oss: + prefix = f"model.layers.{layer_idx}.mlp.experts." + experts = model_tmp.model.layers[layer_idx].mlp.experts + else: + prefix = f"model.layers.{layer_idx}.feed_forward.experts." 
+ experts = model_tmp.model.layers[layer_idx].feed_forward.experts + fused_key = prefix + "gate_up_proj" - fused_bias_key = prefix + "gate_up_proj_bias" gate_key = prefix + "gate_proj" up_key = prefix + "up_proj" - gate_bias_key = prefix + "gate_proj_bias" - up_bias_key = prefix + "up_proj_bias" - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- + # Check if we have bias terms (GptOss case) + has_bias = fused_key + "_bias" in sd + if has_bias: + fused_bias_key = fused_key + "_bias" + gate_bias_key = gate_key + "_bias" + up_bias_key = up_key + "_bias" + + # ---- split weights based on model type ---------------------- fused = sd[fused_key] # [E, H, 2I] - fused_bias = sd[fused_bias_key] # [E, 2I] E, H, two_I = fused.shape - # ffn_dim = two_I // 2 - # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] - gate = fused[..., ::2] # [E, H, I] - even indices - up = fused[..., 1::2] # [E, H, I] - odd indices - gate_bias = fused_bias[..., ::2] # [E, I] - even indices - up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + if is_gpt_oss: + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] + gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + else: + # For standard MoE, gate/up are concatenated: [gate, up] + ffn_dim = two_I // 2 + gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].mlp.experts + # Copy weights to model experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) - experts.gate_proj_bias.data.copy_(gate_bias) - experts.up_proj_bias.data.copy_(up_bias) + + # Handle bias if present + if has_bias: + fused_bias = sd[fused_bias_key] # [E, 2I] + + if is_gpt_oss: + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + else: + ffn_dim = fused_bias.shape[-1] // 2 + gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1) + + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) # ---- update the state-dict so load_state_dict sees the right keys sd[gate_key] = gate sd[up_key] = up - sd[gate_bias_key] = gate_bias - sd[up_bias_key] = up_bias + if has_bias: + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + # Delete fused keys if delete_fused_key: del sd[fused_key] - del sd[fused_bias_key] + if has_bias: + del sd[fused_bias_key] logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") transformed = True @@ -250,4 +224,5 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +# Keep the existing list of supported models VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..f00c08d25 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache +from transformers.cache_utils import EncoderDecoderCache, HybridCache, HybridChunkedCache from QEfficient.customop import ( CtxGatherFunc, @@ -23,18 +23,142 @@ ) -class QEffDynamicCache(DynamicCache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. 
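For readers comparing the two checkpoint layouts that the unified split transform above has to handle, here is a small standalone sketch (illustration only, with invented shapes; the real tensors come from the model state dict, and bias handling follows the same slicing):

import torch

E, H, I = 4, 8, 16                               # experts, hidden size, intermediate size (toy values)
fused = torch.randn(E, H, 2 * I)                 # stand-in for experts.gate_up_proj

# GptOss checkpoints interleave gate/up channels: [g0, u0, g1, u1, ...]
gate_i, up_i = fused[..., ::2], fused[..., 1::2]          # each [E, H, I]

# Llama4-style checkpoints concatenate them: [gate | up]
gate_c, up_c = fused.split(I, dim=-1)                     # views, each [E, H, I]

assert gate_i.shape == up_i.shape == gate_c.shape == (E, H, I)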
+class QEffDynamicCache: + def __init__(self) -> None: + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] + ) -> "QEffDynamicCache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + # Directly populate the cache lists + cache.key_cache.append(key_states) + cache.value_cache.append(value_states) + return cache - - Optimized implementation for the Cloud AI 100 to reuse KV Cache. - - get the position_ids input using kwargs. - - Use custom Onnxscript ops to write optimized version to generate Onnx model. + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length - """ + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + return len(self.key_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + """ + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + # Scatter + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ).clone() + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ).clone() + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply( + self.key_cache[layer_idx], position_ids, key_states + ).clone() + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], position_ids, value_states + ).clone() + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ @@ -113,80 +237,6 @@ def read_only(self, layer_idx, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. 
- """ - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) - else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], position_ids, value_states - ) - - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Gather - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - - return k_out, v_out - def update3D( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 60c48b3ad..82787e1d4 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -67,10 +67,6 @@ def forward(self, hidden: torch.Tensor): for e in range(self.experts.num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - # Skip if no tokens routed to this expert - # if (routing_weight > 0).sum() == 0: - # continue - W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] W_d = self.experts.down_proj[e] # [I, H] @@ -98,9 +94,75 @@ def forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + # V2 + # B, S, H = hidden.shape + # T = B * S # Total number of tokens + + # hidden = hidden.view(T, H) + + # router_logits = F.linear(hidden, self.router.weight, self.router.bias) # [T, num_experts] + # top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + # top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + # routing_weights = torch.zeros_like(router_logits) # [T, num_experts] + # routing_weights.scatter_(1, top_i, top_w) # Scatter top-k weights to their positions + + # ffn_dim = self.experts.intermediate_size # Intermediate dimension + # upgate = hidden.new_zeros((T, ffn_dim)) # Buffer for up-gate activations + # expert_out = 
hidden.new_zeros((T, H)) # Buffer for final expert outputs + + # for e in range(self.experts.num_experts): + # # Get routing weight for this expert + # routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + # # Get weight matrices and biases for this expert + # W_g = self.experts.gate_proj[e] # [H, ffn_dim] + # W_u = self.experts.up_proj[e] # [H, ffn_dim] + # b_g = self.experts.gate_proj_bias[e] # [ffn_dim] + # b_u = self.experts.up_proj_bias[e] # [ffn_dim] + + # # ===== Gate projection with bias and clamping ===== + # gate = (hidden @ W_g) + b_g # [T, ffn_dim] + # gate = gate.clamp(min=None, max=self.experts.limit) + + # # ===== Up projection with bias and clamping ===== + # up = (hidden @ W_u) + b_u # [T, ffn_dim] + # up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + # glu = gate * torch.sigmoid(gate * self.experts.alpha) + # intermediate = (up + 1) * glu # [T, ffn_dim] + # masked_intermediate = torch.where( + # routing_weight > 0, + # intermediate, + # torch.zeros_like(upgate) + # ) + + # # ===== Accumulate to upgate buffer using += ===== + # # The += operator is important for compiler optimization + # upgate += masked_intermediate + + # for e in range(self.experts.num_experts): + # # Get routing weight for this expert + # routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + # # Get down projection matrix and bias + # W_d = self.experts.down_proj[e] # [ffn_dim, H] + # b_d = self.experts.down_proj_bias[e] # [H] + # down_out = (upgate @ W_d) + b_d # [T, H] + # down_out = down_out * routing_weight # [T, H] + # masked_down = torch.where( + # routing_weight > 0, + # down_out, + # torch.zeros_like(expert_out) + # ) + # expert_out += masked_down + + # expert_out = expert_out.view(B, S, H) + + # # Return output and router logits for loss computation + # return expert_out, router_logits + # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology -class QEffQwen2RotaryEmbedding(GptOssRotaryEmbedding): +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -221,7 +283,7 @@ class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __qeff_init__(self): - self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) def forward( self, @@ -370,24 +432,12 @@ def forward( target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + # TODO: Enable SWA # causal_mask = _create_causal_mask( # position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window # ) - # # It may already have been prepared by e.g. 
`generate` - # if not isinstance(causal_mask_mapping := attention_mask, dict): - # mask_kwargs = { - # "config": self.config, - # "input_embeds": inputs_embeds, - # "attention_mask": attention_mask, - # "cache_position": cache_position, - # "past_key_values": past_key_values, - # } - # causal_mask_mapping = { - # "full_attention": create_causal_mask(**mask_kwargs), - # "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), - # } - hidden_states = inputs_embeds # position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -505,25 +555,6 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - # # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - # logits = self.lm_head(hidden_states[:, slice_indices, :]) - - # loss = None - # if labels is not None: - # loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - # aux_loss = None - # if output_router_logits: - # aux_loss = load_balancing_loss_func( - # outputs.router_logits, - # self.num_experts, - # self.num_experts_per_tok, - # attention_mask, - # ) - # if labels is not None: - # loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - return MoeCausalLMOutputWithPast( loss=None, aux_loss=None, diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 60b1c929d..7a084abcb 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,7 +24,6 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - logger, repeat_kv, rotate_half, ) @@ -298,10 +297,6 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and not self.training: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
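On the mask construction used in the model forward above, the following is only a rough approximation of what a position_ids-driven causal mask computes; the actual _create_causal_mask helper may differ in its mask value and in the sliding-window variant left as a TODO:

import torch

position_ids = torch.tensor([[3, 4, 5]])                   # [B, S] query positions
target_length = 8                                           # context / KV length
key_pos = torch.arange(target_length)[None, None, :]        # [1, 1, target_length]
visible = key_pos <= position_ids[:, :, None]                # key k visible iff k <= query position
neg = torch.tensor(torch.finfo(torch.float32).min)
causal_mask = torch.where(visible, torch.tensor(0.0), neg).unsqueeze(1)   # [B, 1, S, target_length]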
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8690ea03a..2f3ee3dc0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -27,7 +27,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform -from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform, SplitGateUpWeightsTransformGPTOSS +from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( CloudAI100ExecInfoNew, @@ -1381,7 +1381,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, - SplitGateUpWeightsTransformGPTOSS, KVCacheExternalModuleMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/pyproject.toml b/pyproject.toml index 00b2bb7ef..5dfb00e63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,10 @@ classifiers = [ requires-python = ">=3.8,<3.11" dependencies = [ "transformers==4.55.0", - "huggingface-hub==0.30.0", + "huggingface-hub", "hf_transfer==0.1.9", - "peft==0.13.2", - "datasets==2.20.0", + "peft", + "datasets", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", From 30ed22269e0201b291a4cffd12ae93f52f5be773 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Fri, 8 Aug 2025 02:44:05 +0000 Subject: [PATCH 08/10] nit: update test gpt file Signed-off-by: vbaddi --- tests/test_gpt.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 92c17c353..8e44f2f82 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -1,27 +1,39 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + import torch -from transformers import AutoConfig, AutoModelForCausalLM, GptOssForCausalLM, TextStreamer +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils._utils import load_hf_tokenizer from QEfficient.utils.constants import Constants from QEfficient.utils.run_utils import ApiRunner -Constants.INPUT_STR=["Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve"] +Constants.INPUT_STR = [ + "Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. 
Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve" +] torch.manual_seed(42) model_id = "openai/gpt-oss-20b" config = AutoConfig.from_pretrained(model_id) -config.num_hidden_layers=2 +config.num_hidden_layers = 2 # Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues if hasattr(config, "quantization_config"): delattr(config, "quantization_config") model = GptOssForCausalLM.from_pretrained( - "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", torch_dtype=torch.float32, attn_implementation="eager", config=config + "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", + torch_dtype=torch.float32, + attn_implementation="eager", + config=config, ) model.eval() -model.generation_config.sample=False +model.generation_config.sample = False tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) config = model.config batch_size = len(Constants.INPUT_STR) @@ -58,4 +70,4 @@ print(pytorch_hf_tokens) print(exec_info) -assert (exec_info.generated_ids[0][0,:159] == pytorch_hf_tokens).all() +assert (exec_info.generated_ids[0][0, :159] == pytorch_hf_tokens).all() From 14afedb602b6603c99c1a5a55b31192409334943 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 8 Aug 2025 12:36:11 +0530 Subject: [PATCH 09/10] MOE optimized Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 163 +++++++++++------- 1 file changed, 97 insertions(+), 66 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 8e82e702d..855543c5f 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -42,7 +42,7 @@ def __qeff_init__(self): class QEffGptOssMLP(GptOssMLP): - def forward(self, hidden: torch.Tensor): + def alt_forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -94,71 +94,101 @@ def forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits - # V2 - # B, S, H = hidden.shape - # T = B * S # Total number of tokens - - # hidden = hidden.view(T, H) - - # router_logits = F.linear(hidden, self.router.weight, self.router.bias) # [T, num_experts] - # top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] - # top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) - # routing_weights = torch.zeros_like(router_logits) # [T, num_experts] - # routing_weights.scatter_(1, top_i, top_w) # Scatter top-k weights to their positions - - # ffn_dim = self.experts.intermediate_size # Intermediate dimension - # upgate = hidden.new_zeros((T, ffn_dim)) # Buffer for up-gate activations - # expert_out = hidden.new_zeros((T, H)) # Buffer for final expert outputs - - # for e in range(self.experts.num_experts): - # # Get routing weight for this expert - # routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - - # # Get weight matrices and biases for this expert - # W_g = self.experts.gate_proj[e] # [H, ffn_dim] - # W_u = self.experts.up_proj[e] # [H, ffn_dim] - # b_g = self.experts.gate_proj_bias[e] # [ffn_dim] - # b_u = self.experts.up_proj_bias[e] # [ffn_dim] - - # # ===== Gate projection with bias and clamping ===== - # gate = (hidden @ W_g) + b_g # 
[T, ffn_dim] - # gate = gate.clamp(min=None, max=self.experts.limit) - - # # ===== Up projection with bias and clamping ===== - # up = (hidden @ W_u) + b_u # [T, ffn_dim] - # up = up.clamp(min=-self.experts.limit, max=self.experts.limit) - # glu = gate * torch.sigmoid(gate * self.experts.alpha) - # intermediate = (up + 1) * glu # [T, ffn_dim] - # masked_intermediate = torch.where( - # routing_weight > 0, - # intermediate, - # torch.zeros_like(upgate) - # ) - - # # ===== Accumulate to upgate buffer using += ===== - # # The += operator is important for compiler optimization - # upgate += masked_intermediate - - # for e in range(self.experts.num_experts): - # # Get routing weight for this expert - # routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - - # # Get down projection matrix and bias - # W_d = self.experts.down_proj[e] # [ffn_dim, H] - # b_d = self.experts.down_proj_bias[e] # [H] - # down_out = (upgate @ W_d) + b_d # [T, H] - # down_out = down_out * routing_weight # [T, H] - # masked_down = torch.where( - # routing_weight > 0, - # down_out, - # torch.zeros_like(expert_out) - # ) - # expert_out += masked_down - - # expert_out = expert_out.view(B, S, H) - - # # Return output and router logits for loss computation - # return expert_out, router_logits + def forward(self, hidden_states: torch.Tensor): + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + + # Top-k selection + top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + # Creating experts mask and routing weights masked + awesome_experts_mask_1 = ( + torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_2 = ( + torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_3 = ( + torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_4 = ( + torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + + gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + 
gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology @@ -383,6 +413,7 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + # alth, _ = self.mlp.alt_forward(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) From 99f4795bc7c3d0f053b76218a883867958555bd3 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Mon, 11 Aug 2025 07:00:22 +0000 Subject: [PATCH 10/10] nit: update modeling with new decode moe forward Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 47 ++++++++++++++++++- .../transformers/models/modeling_auto.py | 2 +- .../transformers/models/pytorch_transforms.py | 4 +- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 855543c5f..cf1228fae 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -94,7 +94,52 @@ def alt_forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits - def forward(self, hidden_states: torch.Tensor): + # ------------------- Gather based, weights as activation approach --------------- + def forward(self, hidden_states): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts + gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()] + gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + 
down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Apply Chosen Experts (without routing weights first) + # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0) + # expert_in = expert_in.view(-1, 1, self.experts.hidden_size) + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation (This is before on Llama4) + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + def optimized_moe_forward(self, hidden_states: torch.Tensor): B, S, H = hidden_states.shape T = B * S hidden_states = hidden_states.view(T, H) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 470bf65d6..25f1475b9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1380,7 +1380,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): FP8DeQuantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, - SplitGateUpWeightsTransform, + # SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 20839590f..f220743a9 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -54,7 +54,6 @@ from transformers.models.gpt_oss.modeling_gpt_oss import ( GptOssAttention, GptOssDecoderLayer, - GptOssExperts, GptOssForCausalLM, GptOssMLP, GptOssModel, @@ -211,7 +210,6 @@ from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( QEffGptOssAttention, QEffGptOssDecoderLayer, - QEffGptOssExperts, QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, @@ -423,7 +421,7 @@ class KVCacheTransform(ModuleMappingTransform): GptOssModel: QEffGptOssModel, GptOssForCausalLM: QEffGptOssForCausalLM, GptOssMLP: QEffGptOssMLP, - GptOssExperts: QEffGptOssExperts, + # GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM,
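To make the decode-path expert forward easier to follow in isolation, here is a self-contained sketch of the gather-plus-bmm pattern with invented shapes; the alpha value is an assumption and the clamp/limit handling is omitted for brevity:

import torch
import torch.nn.functional as F

T, H, I, E, K = 6, 8, 16, 4, 2                    # tokens, hidden, intermediate, experts, top-k (toy values)
hidden = torch.randn(T, H)
router_logits = torch.randn(T, E)
gate_up_proj = torch.randn(E, H, 2 * I)           # per-expert fused weights
down_proj = torch.randn(E, I, H)

top_w, top_i = torch.topk(router_logits, K, dim=-1)
top_w = F.softmax(top_w, dim=-1)

# Gather the weights of the chosen experts as activations: [T*K, H, 2I] and [T*K, I, H]
w_gu = gate_up_proj[top_i.flatten()]
w_d = down_proj[top_i.flatten()]

# Repeat each token once per chosen expert and batch the matmuls with bmm
x = hidden.unsqueeze(1).expand(-1, K, -1).reshape(-1, 1, H)     # [T*K, 1, H]
gate_up = torch.bmm(x, w_gu)                                    # [T*K, 1, 2I]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]                # interleaved gate/up channels
act = (up + 1) * (gate * torch.sigmoid(gate * 1.702))           # GLU-style activation, alpha assumed 1.702
out = torch.bmm(act, w_d).view(T, K, H)                         # [T, K, H]
out = (out * top_w.unsqueeze(-1)).sum(dim=1)                    # routing weights applied after the experts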
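And a companion sketch of the mask-based variant kept as optimized_moe_forward: every expert processes every token, and one-hot masks built from the top-k indices zero out contributions a token did not route to (top-1 only here, shapes invented, the expert activation reduced to a plain matmul):

import torch
import torch.nn.functional as F

T, H, I, E = 6, 8, 16, 4
hidden = torch.randn(T, H)
router_logits = torch.randn(T, E)
w_in = torch.randn(E, H, I)                       # stand-in for per-expert gate/up weights
w_out = torch.randn(E, I, H)                      # stand-in for per-expert down weights

top_w, top_i = torch.topk(router_logits, 1, dim=-1)                    # top-1 routing
mask = F.one_hot(top_i[:, 0], num_classes=E).bool().T.unsqueeze(-1)    # [E, T, 1]

acc = torch.zeros(T, I)
for e in range(E):
    h = hidden @ w_in[e]                                    # every token runs through every expert
    acc += torch.where(mask[e], h, torch.zeros_like(h))     # keep only tokens routed to expert e

out = torch.zeros(T, H)
for e in range(E):
    y = acc @ w_out[e]
    out += torch.where(mask[e], y, torch.zeros_like(y))
out = out * top_w                                           # routing weight, [T, 1] broadcast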