- 
                Notifications
    
You must be signed in to change notification settings  - Fork 594
 
Use new DeviceMesh unflatten to rewrite parallel_dims #1660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
12eca61    to
    19e4a23      
    Compare
  
    | 
               | 
          ||
| return mesh | ||
| if self._meshes[dim].size() == 1: | ||
| return None | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this will break user expectation. We got asks that DTensor redistribute running on a mesh of size 1 should perform no op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But even for current TorchTitan, we won't create any DeviceMesh if the parallelism degree is 1. So it is unclear to me how DeviceMesh with size 1 exists?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not in torchtitan, in internal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTorch? Then it is okay, right? DeviceMesh still supports the case but TorchTitan makes a stronger assumption in our use case.
19e4a23    to
    178bc11      
    Compare
  
    | fsdp = self.dp_shard * self.cp | ||
| efsdp = fsdp * self.tp // (self.etp * self.ep) | ||
| 
               | 
          ||
| self._world_mesh = init_device_mesh( | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this initialize a world PG?
it may be fine to just ignore this for now in torchtitan, but, i am wondering if users want control over world group creation what would that look like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc., @fduwjj are we able to disable the global PG initialization?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so right now we don't use split, so we can make it a fake pg. But if split is needed then we need to materialize the world PG anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should modify FLUX train.py as it's in core now.
@ruisizhang123 let's adapt SimpleFSDP after this PR is merged.
oh it seems being fixed in #1959
| 
               | 
          ||
| return mesh | ||
| if self._meshes[dim].size() == 1: | ||
| return None | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not in torchtitan, in internal
| ) | ||
| 
               | 
          ||
| self._meshes = { | ||
| "pp": dataloading_mesh["pp"], | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to confirm if things match my expected behavior:
- PG will be created for each sub dimension during 
unflatten, unlessbackend_overrideis specified on some dimension with the"fake"backend. flattenwill create a new mesh and a new PG.- slicing will create a new mesh, but reuse the PG created in parent mesh.
 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Yes
 - Yes
 - Yes
 
| self._world_mesh = init_device_mesh( | ||
| device_type, (self.world_size,), mesh_dim_names=("world",) | ||
| ) | ||
| dataloading_mesh = unflatten_mesh( | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what will happen if self.pp * batch * self.cp * self.tp != world_size? Will the _unflatten() fail?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it will fail
This is a demonstration of how parallel_dims will be when using pytorch/pytorch#161224 stack. ghstack-source-id: d29d2e2 Pull-Request: #1885
20910ef    to
    a67e87a      
    Compare
  
    
Summary
This PR utilizes the latest APIs provided by DeviceMesh to simplify the creation of all different meshes.
The design philosophy is as follow: