@@ -17,7 +17,7 @@
 )
 
 from torchtitan.components.loss import LossFunction
-from torchtitan.config import JobConfig
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.pipeline_parallel import (
     build_pipeline_schedule,
@@ -27,7 +27,23 @@
 
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
-
+from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
+
+root_mod = None
+
+class PipelineStagesWrapper(nn.Module):
+    """Wrapper to establish parent-child relationship for pipeline stages."""
+    def __init__(self, stages):
+        super().__init__()
+        # Store stages as actual child modules
+        for i, stage in enumerate(stages):
+            self.add_module(f"stage_{i}", stage)
+
+    def forward(self, x):
+        # This won't be called in pipeline mode, but FSDP requires it
+        for stage in self.children():
+            x = stage(x)
+        return x
 
 def pipeline_llama(
     model: nn.Module,
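Why the wrapper: `fully_shard` operates on a single `nn.Module` and its children, while pipeline splitting leaves each stage as a free-standing module with no common parent. A minimal standalone sketch of that ownership trick, using toy `nn.Linear` stages instead of transformer chunks (nothing below is torchtitan code):

```python
import torch.nn as nn

class StagesWrapper(nn.Module):
    """Registers each pipeline stage as a child so one root module owns them all."""
    def __init__(self, stages):
        super().__init__()
        for i, stage in enumerate(stages):
            self.add_module(f"stage_{i}", stage)

stages = [nn.Linear(8, 8), nn.Linear(8, 8)]
wrapper = StagesWrapper(stages)

print([name for name, _ in wrapper.named_children()])  # ['stage_0', 'stage_1']
print(sum(p.numel() for p in wrapper.parameters()))    # 144 = 2 * (8*8 + 8)
# Parameters are shared by reference, not copied, so sharding the wrapper
# later acts on the very tensors the pipeline stages run with.
print(stages[0].weight is wrapper.stage_0.weight)      # True
```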
@@ -136,6 +152,21 @@ def pipeline_llama( |
         # in case the model is modified e.g. by torch.compile
         stages[i].submod = m
 
+    if parallel_dims.fsdp_enabled:
+        world_mesh = parallel_dims.world_mesh
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+        else:
+            dp_mesh_dim_names = ("dp_shard_cp",)
+
+        mp_policy = MixedPrecisionPolicy(param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce])
+        fsdp_config = {"mesh": world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy}
+
+        # Wrap the model parts into a root-level FSDP Module
+        parent_module = PipelineStagesWrapper(model_parts)
+        global root_mod
+        root_mod = fully_shard(parent_module, **fsdp_config)
+
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
     # This is used in the train loop to determine whether to pass in the input_ids and labels
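To try the wrapping call itself outside torchtitan, here is a hedged, self-contained sketch of the same `fully_shard` + `MixedPrecisionPolicy` pattern over a 1-D device mesh. It assumes a `torchrun` launch on CUDA devices with a recent PyTorch that exposes `fully_shard` under `torch.distributed.fsdp`; the mesh name, dtypes, and toy stages are illustrative, not torchtitan's configuration.

```python
# Sketch only: mirrors the diff's pattern (wrap the stages, then fully_shard the root),
# assuming `torchrun --nproc_per_node=<N> sketch.py` on CUDA devices.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

class StagesWrapper(nn.Module):
    """Toy stand-in for PipelineStagesWrapper: one root module owning all stages."""
    def __init__(self, stages):
        super().__init__()
        for i, stage in enumerate(stages):
            self.add_module(f"stage_{i}", stage)

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp_shard",))

stages = [nn.Linear(1024, 1024) for _ in range(2)]  # stand-ins for pipeline stages
wrapper = StagesWrapper(stages)

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
root_mod = fully_shard(wrapper, mesh=mesh, mp_policy=mp_policy)

# Each stage module stays reachable (e.g. for a pipeline schedule), while
# root_mod is the single FSDP root the training loop can keep a handle on.
print(type(root_mod).__name__, [n for n, _ in root_mod.named_children()])
dist.destroy_process_group()
```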