# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import copy
import importlib
from contextlib import nullcontext
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
-import torch.distributed._functional_collectives as funcol
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
-from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.distributed_c10d import ReduceOp
-from torch.distributed.tensor import DTensor
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import FaultTolerance as FTConfig

if importlib.util.find_spec("torchft") is not None:
    import torchft as ft
class FTManager:
    def __init__(
        self,
-        manager: Optional["ft.Manager"],
-        group_size: int = 1,
-        replica_id: int = 0,
+        ft_config: FTConfig,
    ) -> None:
-        self._manager = manager
-        self.group_size = group_size
-        self.replica_id = replica_id
-        if has_torchft and manager is not None:
+        if not ft_config.enable:
+            self._manager = None
+            return
+
+        if not has_torchft:
+            raise ImportError("torchft is not installed. Please install it.")
+
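+        # NCCL process group that torchft manages for communication across replicas.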
+        pg = ft.ProcessGroupNCCL()
+
+        # If a semi-sync training method is specified, the quorum should be synchronous.
+        self.use_async_quorum = ft_config.semi_sync_method is None
+
+        self._manager = ft.Manager(
+            pg=pg,
+            min_replica_size=ft_config.min_replica_size,
+            load_state_dict=None,
+            state_dict=None,
+            use_async_quorum=self.use_async_quorum,
+            replica_id=f"torchtitan_ft_{ft_config.replica_id}",
+        )
+        self.group_size = ft_config.group_size
+        self.replica_id = ft_config.replica_id
+
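+        # Semi-sync methods (DiLoCo/LocalSGD) handle cross-replica synchronization
+        # themselves, so the managed replicate process group is only needed for
+        # async-quorum training.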
+        if self.use_async_quorum:
            self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
            self.replicate_pg.register("dp_replicate")

@@ -53,96 +67,46 @@ def manager(self) -> "ft.Manager":
        return self._manager

    def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
-        return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
-
-    def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
-        def all_reduce_hook(output):
-            dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
-
-        def apply_set_all_reduce_hook(m):
-            if isinstance(m, FSDPModule):
-                m.set_all_reduce_hook(all_reduce_hook)
-
-        for part in model_parts:
-            part.apply(apply_set_all_reduce_hook)
-
-
-def init_ft_manager(job: JobConfig) -> FTManager:
-    """Initialize the FT manager if TorchFT is enabled.
-
-    Args:
-        job (JobConfig): The job configuration.
-
-    Returns:
-        FTManager: A wrapper around TorchFT.Manager
-    """
-    if not job.fault_tolerance.enable:
-        return FTManager(None)
-
-    if not has_torchft:
-        raise ImportError("torchft is not installed. Please install it.")
-
-    if job.fault_tolerance.min_replica_size < 1:
-        raise ValueError("At least one FT replica is required.")
-
-    pg = ft.ProcessGroupNCCL()
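+        # With fault tolerance, data parallelism spans all replica groups: scale the
+        # degree by group_size and offset the local dp_rank by this replica's position.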
+        if self.enabled:
+            return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
+        else:
+            return dp_degree, dp_rank

-    # If a semi-sync training method is specified, the quorum should be synchronous
-    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+    def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
+        if self.enabled and self.use_async_quorum:

-    return FTManager(
-        ft.Manager(
-            pg=pg,
-            min_replica_size=job.fault_tolerance.min_replica_size,
-            load_state_dict=None,
-            state_dict=None,
-            use_async_quorum=use_async_quorum,
-            replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
-        ),
-        group_size=job.fault_tolerance.group_size,
-        replica_id=job.fault_tolerance.replica_id,
-    )
-
-
-def ft_dist_reduce(
-    x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
-) -> tuple[torch.Tensor, str, DeviceMesh]:
-    if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh):
-        x = funcol.all_reduce(
-            x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
-        )
-        return x, reduceOp, mesh.managed_mesh.mesh
-    return x, reduceOp, mesh
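+            # FSDP reduces gradients only within this replica group; additionally
+            # average the reduced gradients across replica groups via the managed
+            # process group.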
+            def all_reduce_hook(output):
+                dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)

+            def apply_set_all_reduce_hook(m):
+                if isinstance(m, FSDPModule):
+                    m.set_all_reduce_hook(all_reduce_hook)

-def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
-    if has_torchft:
-        mesh = total_norm._spec.mesh
-        if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh):
-            # The gradients along the replicated dim have already been reduced,
-            # so we don't need another reduction before removing the
-            # replicate dimension.
-            local_tensor = total_norm.to_local()
-            placements = list(copy.copy(total_norm._spec.placements))
-            placements.pop(mesh.replicate_dim)
-            return DTensor.from_local(local_tensor, mesh.mesh, placements)
+            for model_part in model_parts:
+                model_part.apply(apply_set_all_reduce_hook)

-    return total_norm
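+    # Process group for syncing the loss across replicas; None when fault tolerance is
+    # disabled or a semi-sync method handles cross-replica synchronization itself.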
+    @property
+    def loss_sync_pg(
+        self,
+    ) -> Optional["ft.process_group.ManagedProcessGroup"]:
+        if self.enabled and self.use_async_quorum:
+            return self.replicate_pg
+        else:
+            # skip loss sync when using semi-sync training
+            return None


def maybe_semi_sync_training(
-    config: JobConfig,
+    ft_config: FTConfig,
    ft_manager: FTManager,
    model_parts: list[torch.nn.Module],
    optimizer: torch.optim.Optimizer,
) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
    """
    If TorchFT is enabled and a semi-sync method is configured, run training under it.
141107 """
142- ft_config = config .fault_tolerance
143108 semi_sync_method = ft_config .semi_sync_method
144- torchft_enabled = ft_config .enable
145- if torchft_enabled and semi_sync_method is not None :
109+ if ft_config .enable and semi_sync_method is not None :
146110 from torchft import local_sgd
147111
148112 assert (