Skip to content

Commit c092cff

Browse files
committed
wrap stages in FSDP root mod
1 parent eb13ba2 commit c092cff

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

torchtitan/models/llama3/infra/pipeline.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
from torchtitan.components.loss import LossFunction
20-
from torchtitan.config import JobConfig
20+
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
2121
from torchtitan.distributed import ParallelDims
2222
from torchtitan.distributed.pipeline_parallel import (
2323
build_pipeline_schedule,
@@ -27,7 +27,23 @@
2727

2828
from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
2929
from torchtitan.tools.logging import logger
30-
30+
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
31+
32+
root_mod = None
33+
34+
class PipelineStagesWrapper(nn.Module):
    """Wrapper to establish parent-child relationship for pipeline stages.

    Every stage passed to the constructor is registered as a direct child
    module named ``stage_<index>``, in the order given, so that all stages
    share this wrapper as their common parent module.
    """

    def __init__(self, stages):
        """Register each module in ``stages`` as a child of this wrapper.

        Args:
            stages: Iterable of modules; the i-th one becomes the child
                ``stage_<i>``. An empty iterable yields a wrapper with no
                children.
        """
        super().__init__()
        # Register through add_module (instead of keeping a plain Python
        # list) so the stages are actual child modules and show up in
        # children(), named_children() and parameters().
        for i, stage in enumerate(stages):
            self.add_module(f"stage_{i}", stage)

    def forward(self, x):
        """Feed ``x`` through the child stages in registration order.

        Args:
            x: Input handed to the first stage; each stage's output is the
                next stage's input.

        Returns:
            The output of the last stage, or ``x`` unchanged when the
            wrapper has no children.
        """
        # This won't be called in pipeline mode, but FSDP requires it
        for stage in self.children():
            x = stage(x)
        return x
3147

3248
def pipeline_llama(
3349
model: nn.Module,
@@ -136,6 +152,21 @@ def pipeline_llama(
136152
# in case the model is modified e.g. by torch.compile
137153
stages[i].submod = m
138154

155+
if parallel_dims.fsdp_enabled:
156+
world_mesh = parallel_dims.world_mesh
157+
if parallel_dims.dp_replicate_enabled:
158+
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
159+
else:
160+
dp_mesh_dim_names = ("dp_shard_cp",)
161+
162+
mp_policy = MixedPrecisionPolicy(param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce])
163+
fsdp_config = {"mesh": world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy}
164+
165+
# Wrap the model parts into a root-level FSDP Module
166+
parent_module = PipelineStagesWrapper(model_parts)
167+
global root_mod
168+
root_mod = fully_shard(parent_module, **fsdp_config)
169+
139170
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
140171

141172
# This is used in the train loop to determine whether to pass in the input_ids and labels

0 commit comments

Comments
 (0)