
self.gate dtype update for GLM-4.5 #22203


Merged: 5 commits merged into vllm-project:main on Aug 5, 2025

Conversation

zRzRzRzRzRzRzR (Contributor)

The entire self.gate module needs to remain in float32 during the forward pass to preserve benchmark performance for GLM-4.5 and GLM-4.5V.
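For illustration, here is a minimal sketch of the idea in plain PyTorch. It is not the vLLM implementation: `RouterGate`, `hidden_size`, and `n_routed_experts` are illustrative names, and `nn.Linear` stands in for the `ColumnParallelLinear` used by `Glm4MoeBlock`.

```python
import torch
import torch.nn as nn


class RouterGate(nn.Module):
    """Sketch of an MoE router gate kept in float32 while the rest of the
    model runs in a lower-precision dtype such as bfloat16."""

    def __init__(self, hidden_size: int, n_routed_experts: int):
        super().__init__()
        # Create the gate weights in float32 regardless of the model dtype.
        self.gate = nn.Linear(
            hidden_size, n_routed_experts, bias=False, dtype=torch.float32
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Upcast the activations so the routing logits are computed in
        # float32, mirroring the explicit cast added in this PR.
        return self.gate(hidden_states.to(torch.float32))


# Usage: bfloat16 activations produce float32 routing logits.
gate = RouterGate(hidden_size=64, n_routed_experts=8)
logits = gate(torch.randn(2, 64, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32
```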


github-actions bot commented Aug 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot left a comment


Code Review

This pull request aims to ensure the self.gate module in Glm4MoeBlock operates in float32 for performance reasons, as stated in the description.

The change to initialize the gate's weights in float32 is correct. However, the explicit casting of hidden_states to float32 in the forward method is redundant and introduces unnecessary performance overhead. PyTorch's linear layer implementation automatically handles type promotion, ensuring the computation is performed in float32 when the weights are float32.

I've recommended removing the explicit cast to avoid the performance penalty of an extra memory copy.

@@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     if self.n_shared_experts is not None:
         shared_output = self.shared_experts(hidden_states)
-    router_logits, _ = self.gate(hidden_states)
+    router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))

Severity: high

The explicit cast hidden_states.to(dtype=torch.float32) is redundant and introduces an unnecessary memory copy, which can negatively impact performance.

Since self.gate.weight is already of dtype=torch.float32 (due to the change in __init__), torch.nn.functional.linear (which is called internally by ColumnParallelLinear) will automatically perform the matrix multiplication in float32 by upcasting the hidden_states tensor. This implicit type promotion is more efficient than an explicit cast.

Removing the explicit cast will rely on this standard PyTorch behavior and avoid the overhead, while still achieving the goal of performing the gate computation in float32.

Suggested change:
-    router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+    router_logits, _ = self.gate(hidden_states)
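Whether the explicit upcast is actually redundant can be checked empirically. The snippet below is a hypothetical check, not part of the PR: a plain `nn.Linear` stands in for `ColumnParallelLinear`, and the tensor sizes are arbitrary. It reports whether a float32 gate accepts bfloat16 activations directly, as the review comment suggests, or whether the explicit cast is required.

```python
import torch
import torch.nn as nn

# Hypothetical check (not from the PR): float32 gate, bfloat16 activations.
gate = nn.Linear(4096, 128, bias=False, dtype=torch.float32)
x = torch.randn(2, 4096, dtype=torch.bfloat16)

try:
    # Relies on implicit dtype handling, as the review comment suggests.
    print("implicit path output dtype:", gate(x).dtype)
except RuntimeError as err:
    print("implicit path failed:", err)

# Explicit upcast, as done in the PR: the logits are always computed in float32.
print("explicit path output dtype:", gate(x.to(torch.float32)).dtype)
```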

DarkLight1337 enabled auto-merge (squash) August 4, 2025 16:34
github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Aug 4, 2025
Isotr0py added this to the v0.10.1 milestone Aug 4, 2025
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
auto-merge was automatically disabled August 4, 2025 16:38

Head branch was pushed to by a user without write access

mergify bot added the documentation label (Improvements or additions to documentation) Aug 4, 2025
zRzRzRzRzRzRzR (Contributor, Author) commented Aug 4, 2025

Also, the name of the final GLM-V model has been changed.

Isotr0py enabled auto-merge (squash) August 4, 2025 17:26
vllm-bot merged commit 6fa41e0 into vllm-project:main Aug 5, 2025
42 of 44 checks passed
juuice-lee pushed a commit to juuice-lee/vllm-moe.code that referenced this pull request Aug 5, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Mushoz commented Aug 6, 2025

@zRzRzRzRzRzRzR What kind of benchmarks showed degraded performance without this change? A discussion is taking place on the llama.cpp pull request that introduced support for these models, and we are wondering whether llama.cpp would need a similar change, and in which use cases it would help. A perplexity test did not show any improvement when this was changed to float32.

Pull request in question can be found here: ggml-org/llama.cpp#14939

Thank you very much in advance for your clarification!

wenbinc-Bin pushed a commit to wenbinc-Bin/vllm-fork that referenced this pull request Aug 7, 2025
myselvess pushed a commit to myselvess/vllm that referenced this pull request Aug 7, 2025
jingyu-ml pushed a commit to jingyu-ml/vllm that referenced this pull request Aug 8, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
yyihuang pushed a commit to yyihuang/vllm that referenced this pull request Aug 11, 2025
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
Signed-off-by: Avery Yingyi Huang <[email protected]>
Labels: documentation, ready