Skip to content

Conversation

@fegin
Copy link
Contributor

@fegin fegin commented Aug 29, 2025

Summary
This PR utilizes the latest APIs provided by DeviceMesh to simplify the creation of all different meshes.

The design philosophy is as follow:

  1. Create one world mesh with the shape as [world_size,]
  2. Create all 1-D submeshes by using 1) unflattening from the world mesh, or 2) slicing and flatten from other derived meshes.
  3. ParallelDims now provides an API, get_mesh(), which accepts str or list[str]. When the argument is str, the API directly return the corresponding 1-D submesh. If the argument is list[str], the dim names will be used to concatenate to form a n-D device mesh.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 29, 2025
@fegin fegin force-pushed the chienchin/new_device_mesh branch 7 times, most recently from 12eca61 to 19e4a23 Compare October 15, 2025 20:39

return mesh
if self._meshes[dim].size() == 1:
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this will break user expectation. We got asks that DTensor redistribute running on a mesh of size 1 should perform no op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But even for current TorchTitan, we won't create any DeviceMesh if the parallelism degree is 1. So it is unclear to me how DeviceMesh with size 1 exists?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not in torchtitan, in internal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyTorch? Then it is okay, right? DeviceMesh still supports the case but TorchTitan makes a stronger assumption in our use case.

@fegin fegin force-pushed the chienchin/new_device_mesh branch from 19e4a23 to 178bc11 Compare October 28, 2025 20:34
@fegin fegin marked this pull request as ready for review October 28, 2025 21:01
@fegin fegin requested a review from wwwjn as a code owner October 28, 2025 21:01
fsdp = self.dp_shard * self.cp
efsdp = fsdp * self.tp // (self.etp * self.ep)

self._world_mesh = init_device_mesh(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this initialize a world PG?

it may be fine to just ignore this for now in torchtitan, but, i am wondering if users want control over world group creation what would that look like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc., @fduwjj are we able to disable the global PG initialization?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so right now we don't use split, so we can make it a fake pg. But if split is needed then we need to materialize the world PG anyway.

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should modify FLUX train.py as it's in core now.

@ruisizhang123 let's adapt SimpleFSDP after this PR is merged.
oh it seems being fixed in #1959


return mesh
if self._meshes[dim].size() == 1:
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not in torchtitan, in internal

)

self._meshes = {
"pp": dataloading_mesh["pp"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to confirm if things match my expected behavior:

  1. PG will be created for each sub dimension during unflatten, unless backend_override is specified on some dimension with the "fake" backend.
  2. flatten will create a new mesh and a new PG.
  3. slicing will create a new mesh, but reuse the PG created in parent mesh.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Yes
  2. Yes
  3. Yes

self._world_mesh = init_device_mesh(
device_type, (self.world_size,), mesh_dim_names=("world",)
)
dataloading_mesh = unflatten_mesh(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what will happen if self.pp * batch * self.cp * self.tp != world_size? Will the _unflatten() fail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it will fail

fegin added 5 commits November 3, 2025 14:13
This is a demonstration of how parallel_dims will be when using pytorch/pytorch#161224 stack.

ghstack-source-id: d29d2e2
Pull-Request: #1885
ghstack-source-id: f7c3fef
Pull-Request: #1886
ghstack-source-id: cf7ad2a
Pull-Request: #1887
ghstack-source-id: f7c3fef
Pull-Request: #1888
ghstack-source-id: 6173cc5
Pull-Request: #1889
fegin added 9 commits November 3, 2025 14:13
ghstack-source-id: 065ffd4
Pull-Request: #1890
ghstack-source-id: 08dd4a6
Pull-Request: #1891
ghstack-source-id: dcf962b
Pull-Request: #1892
ghstack-source-id: c9fdc96
Pull-Request: #1893
@fegin fegin force-pushed the chienchin/new_device_mesh branch from 20910ef to a67e87a Compare November 3, 2025 23:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants