Skip to content

self.gate dtype update for GLM-4.5 #22203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
Expand Down
2 changes: 1 addition & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def check_available_online(
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
config.n_routed_experts,
bias=False,
quant_config=None,
params_dtype=torch.float32,
prefix=f"{prefix}.gate")

self.gate.e_score_correction_bias = nn.Parameter(
Expand Down Expand Up @@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The explicit cast hidden_states.to(dtype=torch.float32) is redundant and introduces an unnecessary memory copy, which can negatively impact performance.

Since self.gate.weight is already of dtype=torch.float32 (due to the change in __init__), torch.nn.functional.linear (which is called internally by ColumnParallelLinear) will automatically perform the matrix multiplication in float32 by upcasting the hidden_states tensor. This implicit type promotion is more efficient than an explicit cast.

Removing the explicit cast will rely on this standard PyTorch behavior and avoid the overhead, while still achieving the goal of performing the gate computation in float32.

Suggested change
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
router_logits, _ = self.gate(hidden_states)

final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
Expand Down