[EP] add support for ETP=1 #1555
Conversation
Force-pushed from 8143949 to 79b2934
torchtitan/config/job_config.py (outdated)

    Expert parallelism degree. 1 means disabled. No effect for non-MoE models.
    Currently, it is supported with the following constraints:
    - when etp = tp: cp * tp <= ep <= dp_shard * cp * tp
    - when etp = 1: cp * tp <= ep <= dp_shard * cp * tp
Can we also add some comments about the divisibility constraints? For instance, we require ep % cp == 0 and dp_shard * cp % ep == 0. Or do you think this is unnecessary, since people usually use powers of 2?
Sounds good, I can add something. FWIW I used to use the `|` symbol to denote mod x == 0, but people don't seem to understand what it means.
Yeah, I think the `%` symbol is more common?
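To make the constraints concrete, here is a minimal sketch of the divisibility checks being discussed; the helper name is hypothetical and the exact set of checks torchtitan enforces may differ:

# Hypothetical helper illustrating the divisibility constraints mentioned above;
# this is not torchtitan's actual validation code.
def check_ep_divisibility(ep: int, dp_shard: int, cp: int) -> None:
    # ep must be a multiple of cp ...
    if ep % cp != 0:
        raise ValueError(f"ep={ep} must be divisible by cp={cp}")
    # ... and dp_shard * cp must be a multiple of ep.
    if (dp_shard * cp) % ep != 0:
        raise ValueError(f"dp_shard * cp = {dp_shard * cp} must be divisible by ep={ep}")


# Example: dp_shard=4, cp=2, ep=4 satisfies both checks; ep=3 would fail both.
check_ep_divisibility(ep=4, dp_shard=4, cp=2)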
torchtitan/models/moe.py (outdated)

@@ -143,16 +143,17 @@ def init_weights(self, init_std: float):
     nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


-class TokenChoiceTopKRouter(nn.Module):
+class TokenRouter(nn.Module):
I feel the earlier name was more descriptive, since we are doing the torch.topk operation in this module. TokenRouter seems more like just the routing part of token choice.
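For context on the naming discussion, here is a minimal, self-contained sketch of token-choice top-k routing; the class name, scoring function, and shapes are illustrative assumptions, not the actual torchtitan module:

import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenChoiceTopKRouterSketch(nn.Module):
    """Token-choice routing: each token independently picks its top-k experts."""

    def __init__(self, dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # x: (num_tokens, dim) -> scores: (num_tokens, num_experts)
        scores = F.softmax(self.gate(x), dim=-1)
        # The torch.topk call below is the "token choice top-k" part the comment
        # refers to: every token selects its own top-k experts.
        top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=-1)
        return top_scores, selected_experts


# Usage: route 8 tokens of width 16 across 4 experts, 2 experts per token.
router = TokenChoiceTopKRouterSketch(dim=16, num_experts=4, top_k=2)
weights, expert_ids = router(torch.randn(8, 16))  # shapes: (8, 2) and (8, 2)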
I added a couple of nit comments. The loss curves look consistent. LGTM!
@tianyu-l this is very cool. Could you help clarify, at a high level, what the difference is between ETP and EP + TP? Do you have an example config to run it, please?
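On the example-config question: below is a hedged sketch of the relevant parallelism settings for a 16-GPU run with ETP=1. The field names are assumptions based on torchtitan's parallelism config and this PR (in particular the ETP field name), and the degrees are illustrative only; check the actual JobConfig and constraints before relying on them.

# Illustrative only: one possible 16-GPU layout (dp_shard * cp * tp = 8 * 1 * 2).
# Field names are assumed, not verified against the merged PR.
parallelism = {
    "data_parallel_shard_degree": 8,     # dp_shard
    "context_parallel_degree": 1,        # cp
    "tensor_parallel_degree": 2,         # tp
    "expert_parallel_degree": 8,         # ep, with cp * tp <= ep <= dp_shard * cp * tp
    "expert_tensor_parallel_degree": 1,  # etp = 1: expert weights are not tensor-parallel
}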
fixes bug introduced in #1555
…tants (#160805) Used in pytorch/torchtitan#1555 Pull Request resolved: #160805 Approved by: https://github.com/StrongerXi, https://github.com/mlazos
This is a follow-up of the original EP support in #1324.
PR summary
[TBA] description + figure
numerics verification
setup
comparison set