Commit 6c1ce4a

refactor FTManager
1 parent db52d57 commit 6c1ce4a

7 files changed: +66 -116 lines changed


scripts/estimate/estimation.py

Lines changed: 1 addition & 3 deletions

@@ -15,7 +15,6 @@
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore

-from torchtitan.components.ft import init_ft_manager
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.config_manager import ConfigManager, JobConfig
@@ -116,8 +115,7 @@ def estimate_memory(job_config: JobConfig):
     model.train()

     # build optimizer after applying parallelisms to the model
-    ft_manager = init_ft_manager(job_config)
-    optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)
+    optimizers = build_optimizers([model], job_config, parallel_dims)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2

torchtitan/components/ft.py

Lines changed: 50 additions & 86 deletions

@@ -4,19 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import copy
 import importlib
 from contextlib import nullcontext
 from typing import ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
-import torch.distributed._functional_collectives as funcol
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
-from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.distributed_c10d import ReduceOp
-from torch.distributed.tensor import DTensor
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import FaultTolerance as FTConfig

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
@@ -32,14 +28,32 @@
 class FTManager:
     def __init__(
         self,
-        manager: Optional["ft.Manager"],
-        group_size: int = 1,
-        replica_id: int = 0,
+        ft_config: FTConfig,
     ) -> None:
-        self._manager = manager
-        self.group_size = group_size
-        self.replica_id = replica_id
-        if has_torchft and manager is not None:
+        if not ft_config.enable:
+            self._manager = None
+            return
+
+        if not has_torchft:
+            raise ImportError("torchft is not installed. Please install it.")
+
+        pg = ft.ProcessGroupNCCL()
+
+        # If the training method is specific, then the quorum should be synchronous
+        self.use_async_quorum = ft_config.semi_sync_method is None
+
+        self._manager = ft.Manager(
+            pg=pg,
+            min_replica_size=ft_config.min_replica_size,
+            load_state_dict=None,
+            state_dict=None,
+            use_async_quorum=self.use_async_quorum,
+            replica_id=f"torchtitan_ft_{ft_config.replica_id}",
+        )
+        self.group_size = ft_config.group_size
+        self.replica_id = ft_config.replica_id
+
+        if self.use_async_quorum:
             self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
             self.replicate_pg.register("dp_replicate")

@@ -53,96 +67,46 @@ def manager(self) -> "ft.Manager":
         return self._manager

     def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
-        return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
-
-    def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
-        def all_reduce_hook(output):
-            dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
-
-        def apply_set_all_reduce_hook(m):
-            if isinstance(m, FSDPModule):
-                m.set_all_reduce_hook(all_reduce_hook)
-
-        for part in model_parts:
-            part.apply(apply_set_all_reduce_hook)
-
-
-def init_ft_manager(job: JobConfig) -> FTManager:
-    """Initialize the FT manager if TorchFT is enabled.
-
-    Args:
-        job (JobConfig): The job configuration.
-
-    Returns:
-        FTManager: A wrapper around TorchFT.Manager
-    """
-    if not job.fault_tolerance.enable:
-        return FTManager(None)
-
-    if not has_torchft:
-        raise ImportError("torchft is not installed. Please install it.")
-
-    if job.fault_tolerance.min_replica_size < 1:
-        raise ValueError("At least one FT replica is required.")
-
-    pg = ft.ProcessGroupNCCL()
+        if self.enabled:
+            return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
+        else:
+            return dp_degree, dp_rank

-    # If the training method is specific, then the quorum should be synchronous
-    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+    def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
+        if self.enabled and self.use_async_quorum:

-    return FTManager(
-        ft.Manager(
-            pg=pg,
-            min_replica_size=job.fault_tolerance.min_replica_size,
-            load_state_dict=None,
-            state_dict=None,
-            use_async_quorum=use_async_quorum,
-            replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
-        ),
-        group_size=job.fault_tolerance.group_size,
-        replica_id=job.fault_tolerance.replica_id,
-    )
-
-
-def ft_dist_reduce(
-    x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
-) -> tuple[torch.Tensor, str, DeviceMesh]:
-    if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh):
-        x = funcol.all_reduce(
-            x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
-        )
-        return x, reduceOp, mesh.managed_mesh.mesh
-    return x, reduceOp, mesh
+            def all_reduce_hook(output):
+                dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)

+            def apply_set_all_reduce_hook(m):
+                if isinstance(m, FSDPModule):
+                    m.set_all_reduce_hook(all_reduce_hook)

-def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
-    if has_torchft:
-        mesh = total_norm._spec.mesh
-        if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh):
-            # The gradients along the replicated dim has already been reduced.
-            # So we don't need another reducution beforing removing the
-            # replicate dimension
-            local_tensor = total_norm.to_local()
-            placements = list(copy.copy(total_norm._spec.placements))
-            placements.pop(mesh.replicate_dim)
-            return DTensor.from_local(local_tensor, mesh.mesh, placements)
+            for model_part in model_parts:
+                model_part.apply(apply_set_all_reduce_hook)

-        return total_norm
+    @property
+    def loss_sync_pg(
+        self,
+    ) -> Optional["ft.process_group.ManagedProcessGroup"]:
+        if self.enabled and self.use_async_quorum:
+            return self.replicate_pg
+        else:
+            # skip loss sync when using semi-sync training
+            return None


 def maybe_semi_sync_training(
-    config: JobConfig,
+    ft_config: FTConfig,
     ft_manager: FTManager,
     model_parts: list[torch.nn.Module],
     optimizer: torch.optim.Optimizer,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
     """
-    ft_config = config.fault_tolerance
     semi_sync_method = ft_config.semi_sync_method
-    torchft_enabled = ft_config.enable
-    if torchft_enabled and semi_sync_method is not None:
+    if ft_config.enable and semi_sync_method is not None:
         from torchft import local_sgd

         assert (
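
The net effect of this refactor is that FTManager constructs itself from the FaultTolerance config and degrades to a no-op when fault tolerance is disabled, so callers no longer branch on enable or semi_sync_method. A minimal usage sketch based on the diff above; job_config, model_parts, dp_degree, and dp_rank are assumed to come from the surrounding training setup and are not part of this commit:

from torchtitan.components.ft import FTManager

# Construct directly from the fault-tolerance config section. When
# job_config.fault_tolerance.enable is False the manager stays inert.
ft_manager = FTManager(job_config.fault_tolerance)

# Dataloader sharding: scaled across replica groups when FT is enabled,
# otherwise (dp_degree, dp_rank) pass through unchanged.
dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)

# Registers the FSDP all-reduce hook only when the async quorum is in use.
ft_manager.maybe_set_all_reduce_hook(model_parts)

# None when FT is disabled or semi-sync training is configured, so it can
# be handed straight to the loss-reduction helpers as an optional group.
ft_pg = ft_manager.loss_sync_pg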

torchtitan/components/optimizer.py

Lines changed: 3 additions & 3 deletions

@@ -243,7 +243,7 @@ def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
     parallel_dims: ParallelDims,
-    ft_manager: FTManager,
+    ft_manager: FTManager | None = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.

@@ -273,7 +273,7 @@ def build_optimizers(
             raise NotImplementedError(
                 "Optimizers in backward is not supported with Pipeline Parallel."
             )
-        if ft_manager.enabled:
+        if ft_manager and ft_manager.enabled:
             raise NotImplementedError(
                 "TorchFT is not supported with optimizers in backward."
             )
@@ -313,7 +313,7 @@ def build_optimizers(
         model_parts, optimizer_cls, optimizer_kwargs
     )

-    if ft_manager.enabled:
+    if ft_manager and ft_manager.enabled:
         return FTOptimizersContainer(
             model_parts,
             optimizer_cls,
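
Because ft_manager now defaults to None and the checks read `ft_manager and ft_manager.enabled`, call sites that do not use fault tolerance (such as scripts/estimate/estimation.py above) can simply drop the argument. A hedged sketch of both call shapes, with model, job_config, parallel_dims, and ft_manager assumed from the surrounding context:

# Without fault tolerance: omit the argument; it defaults to None and the
# plain OptimizersContainer is built.
optimizers = build_optimizers([model], job_config, parallel_dims)

# With fault tolerance: pass the manager; FTOptimizersContainer is selected
# only when the manager exists and is enabled.
optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)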

torchtitan/distributed/utils.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def _dist_reduce(
     x: torch.Tensor,
     reduceOp: str,
     mesh: DeviceMesh,
-    extra_pg: dist.ProcessGroup | None = None,
+    extra_pg: dist.ProcessGroup | None,
 ) -> float:
     """Perform distributed reduction on a tensor.

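Dropping the default forces every _dist_reduce caller to state explicitly whether an extra replicate process group joins the reduction, which pairs with the loss_sync_pg property above. A hedged sketch of the call shape used in the trainer (names assumed from the train.py diff below):

# loss_sync_pg is None when FT is disabled or semi-sync training is used,
# so it is always safe to pass through; the explicit parameter just makes
# the "no extra group" case visible at the call site.
ft_pg = ft_manager.loss_sync_pg
global_avg_loss = dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg)
global_max_loss = dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg)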
torchtitan/experiments/llama4/optimizer.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def build_llama4_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
     parallel_dims: ParallelDims,
-    ft_manager: FTManager,
+    ft_manager: FTManager | None = None,
 ) -> OptimizersContainer:
     optimizers = build_optimizers(
         model_parts=model_parts,

torchtitan/protocols/train_spec.py

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
 TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer]
 MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
 OptimizersBuilder: TypeAlias = Callable[
-    [list[nn.Module], JobConfig, ParallelDims, FTManager],
+    [list[nn.Module], JobConfig, ParallelDims, FTManager | None],
     OptimizersContainer,
 ]
 LRSchedulersBuilder: TypeAlias = Callable[
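
The relaxed alias means a custom optimizers builder registered through a train spec only has to tolerate a missing manager. A hypothetical builder sketch, mirroring how build_llama4_optimizers above delegates; the build_my_optimizers name is illustrative and some import paths are assumptions about the repo layout:

import torch.nn as nn

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims  # import path assumed


def build_my_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    parallel_dims: ParallelDims,
    ft_manager: FTManager | None = None,
) -> OptimizersContainer:
    # Delegate to the stock builder, which already guards with
    # `ft_manager and ft_manager.enabled` internally.
    return build_optimizers(model_parts, job_config, parallel_dims, ft_manager)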

torchtitan/train.py

Lines changed: 9 additions & 21 deletions

@@ -13,10 +13,10 @@
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record

-import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.ft import FTManager, maybe_semi_sync_training
 from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.components.metrics import (
     build_metrics_processor,
@@ -50,7 +50,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):

     # non-swappable training components
     checkpointer: CheckpointManager
-    ft_manager: ft.FTManager
+    ft_manager: FTManager

     # runtime utilities
     device: torch.device
@@ -104,11 +104,8 @@ def __init__(self, job_config: JobConfig):
         else:
             dp_degree, dp_rank = 1, 0

-        self.ft_manager = ft.init_ft_manager(job_config)
-        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
-        # dataloader must be changed.
-        if self.ft_manager.enabled:
-            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+        self.ft_manager = FTManager(job_config.fault_tolerance)
+        dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

         # take control of garbage collection to avoid stragglers
         self.gc_handler = utils.GarbageCollection(
@@ -259,11 +256,7 @@ def __init__(self, job_config: JobConfig):

         self.model_parts = [model]

-        if (
-            self.ft_manager.enabled
-            and job_config.fault_tolerance.semi_sync_method is None
-        ):
-            self.ft_manager.set_all_reduce_hook(self.model_parts)
+        self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

         # initialize device memory monitor and get peak flops for MFU calculation
         device_memory_monitor = self.metrics_processor.device_memory_monitor
@@ -474,14 +467,9 @@ def train_step(
         if not self.metrics_processor.should_log(self.step):
             return

-        if parallel_dims.dp_cp_enabled or self.ft_manager.enabled:
+        if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
-            # Skip ft manager communication when using semi sync training
-            use_ft_pg = (
-                self.ft_manager.enabled
-                and self.job_config.fault_tolerance.semi_sync_method is None
-            )
-            ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
+            ft_pg = self.ft_manager.loss_sync_pg
             global_avg_loss, global_max_loss = (
                 dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
@@ -508,8 +496,8 @@ def train(self):
             maybe_enable_memory_snapshot(
                 job_config, global_step=self.step
             ) as memory_profiler,
-            ft.maybe_semi_sync_training(
-                job_config,
+            maybe_semi_sync_training(
+                job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
                 model_parts=self.model_parts,
                 optimizer=self.optimizers,
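
With the signature change, maybe_semi_sync_training takes the FaultTolerance config directly instead of the whole JobConfig, and the Trainer no longer branches on enable or semi_sync_method itself. A minimal caller sketch; the training-loop body and the job_config, model_parts, optimizers, and num_steps names are assumed context, not part of this commit:

from torchtitan.components.ft import FTManager, maybe_semi_sync_training

ft_manager = FTManager(job_config.fault_tolerance)

# Returns a DiLoCo/LocalSGD context when a semi_sync_method is configured,
# otherwise a null context, so the training loop can always be wrapped.
with maybe_semi_sync_training(
    job_config.fault_tolerance,
    ft_manager=ft_manager,
    model_parts=model_parts,
    optimizer=optimizers,
):
    for step in range(num_steps):  # placeholder loop
        ...  # forward / backward / optimizer step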
