Fix how SimpleFSDP gets the nD mesh #1959
[ghstack-poisoned]
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #1960
* #1959
* __->__ #1963

pytorch/pytorch#166130 changes the configs and this PR adopts the new configs.

Squash and Merge button won't work for this PR. I'll merge by myself.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #1960
* #1959
* __->__ #1965

**Squash and Merge button won't work for this PR. I'll merge by myself.**

#1963 was accidentally merged with the Squash and Merge button. This is a reland PR.
Nice!
    - assert inner_mesh.mesh_dim_names is not None
    - submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
    - spanned_mesh = outer_global_mesh[submesh_names]
    + spanned_mesh = DeviceMesh._concatenate((outer_mesh, inner_mesh))
Nice!
One nit: isn't the argument supposed to be a list, and not a tuple?
If so, how come there is no type checking or other linting to catch this?
Note that we've already observed TorchTitan being somewhat incorrect with types, e.g., it passes lists to init_device_mesh instead of tuples.
Good point, I believe we have not enabled type checking for TorchTitan, which we should.
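To illustrate why a list/tuple mismatch like the one raised above goes unnoticed without type checking, here is a minimal sketch. The helper `join_dim_names` and the dimension names are hypothetical and are not part of TorchTitan or PyTorch; the point is only that Python ignores annotations at runtime, so a static checker such as mypy or pyright is needed to report the mismatch.

```python
def join_dim_names(groups: list[tuple[str, ...]]) -> tuple[str, ...]:
    """Hypothetical helper whose parameter is annotated as a list."""
    joined: tuple[str, ...] = ()
    for group in groups:
        joined += group
    return joined


# Matches the annotation: the argument is a list.
from_list = join_dim_names([("dp_replicate",), ("dp_shard",)])

# Violates the annotation: the argument is a tuple. Python does not enforce
# annotations at runtime, so this call still succeeds. A static type checker
# would be expected to report an incompatible argument type on this line.
from_tuple = join_dim_names((("dp_replicate",), ("dp_shard",)))

print(from_list == from_tuple)
```

Both calls return the same value at runtime, which is why this kind of mistake tends to surface only once type checking is enabled in CI.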
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* __->__ #1960
* #1959

As title, no logic change.

**Squash and Merge button won't work for this PR. I'll merge by myself.**
Stack from ghstack (oldest at bottom):
This is the recommended way to get the nD mesh now that DeviceMesh has _concatenate().
Squash and Merge button won't work for this PR. I'll merge by myself.
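
For readers unfamiliar with the idea, the following toy sketch shows what concatenating two submeshes into one nD mesh means conceptually. It does not use PyTorch: `ToyMesh`, `concatenate`, and the dimension names and sizes are all hypothetical stand-ins chosen for illustration, and the real `DeviceMesh._concatenate()` may behave differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a device mesh: named dimensions with sizes."""

    dim_names: tuple[str, ...]
    shape: tuple[int, ...]


def concatenate(meshes: list[ToyMesh]) -> ToyMesh:
    """Join submeshes into one nD mesh, keeping the given order (outermost first)."""
    names: tuple[str, ...] = ()
    shape: tuple[int, ...] = ()
    for mesh in meshes:
        # Each dimension name may appear only once in the combined mesh.
        overlap = set(names) & set(mesh.dim_names)
        if overlap:
            raise ValueError(f"duplicate mesh dimension names: {sorted(overlap)}")
        names += mesh.dim_names
        shape += mesh.shape
    return ToyMesh(names, shape)


# Illustrative 1D submeshes; the names and sizes are assumptions.
outer_mesh = ToyMesh(("dp_replicate",), (2,))
inner_mesh = ToyMesh(("dp_shard",), (4,))

spanned_mesh = concatenate([outer_mesh, inner_mesh])
print(spanned_mesh)
```

In this toy version the caller no longer builds a combined name list and indexes a global mesh with it; the two submeshes are handed over directly, which mirrors the shape of the change in this PR.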