@@ -4,19 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import copy
import importlib
from contextlib import nullcontext
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
-import torch.distributed._functional_collectives as funcol
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
-from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.distributed_c10d import ReduceOp
-from torch.distributed.tensor import DTensor
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import FaultTolerance as FTConfig

if importlib.util.find_spec("torchft") is not None:
    import torchft as ft
@@ -32,14 +28,32 @@
class FTManager:
    def __init__(
        self,
-        manager: Optional["ft.Manager"],
-        group_size: int = 1,
-        replica_id: int = 0,
+        ft_config: FTConfig,
    ) -> None:
-        self._manager = manager
-        self.group_size = group_size
-        self.replica_id = replica_id
-        if has_torchft and manager is not None:
+        if not ft_config.enable:
+            self._manager = None
+            return
+
+        if not has_torchft:
+            raise ImportError("torchft is not installed. Please install it.")
+
+        pg = ft.ProcessGroupNCCL()
+
+        # If a semi-sync training method is specified, the quorum should be synchronous
+        self.use_async_quorum = ft_config.semi_sync_method is None
+
+        self._manager = ft.Manager(
+            pg=pg,
+            min_replica_size=ft_config.min_replica_size,
+            load_state_dict=None,
+            state_dict=None,
+            use_async_quorum=self.use_async_quorum,
+            replica_id=f"torchtitan_ft_{ft_config.replica_id}",
+        )
+        self.group_size = ft_config.group_size
+        self.replica_id = ft_config.replica_id
+
+        if self.use_async_quorum:
            self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
            self.replicate_pg.register("dp_replicate")

@@ -53,96 +67,46 @@ def manager(self) -> "ft.Manager":
        return self._manager

    def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
-        return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
-
-    def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
-        def all_reduce_hook(output):
-            dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
-
-        def apply_set_all_reduce_hook(m):
-            if isinstance(m, FSDPModule):
-                m.set_all_reduce_hook(all_reduce_hook)
-
-        for part in model_parts:
-            part.apply(apply_set_all_reduce_hook)
-
-
-def init_ft_manager(job: JobConfig) -> FTManager:
-    """Initialize the FT manager if TorchFT is enabled.
-
-    Args:
-        job (JobConfig): The job configuration.
-
-    Returns:
-        FTManager: A wrapper around TorchFT.Manager
-    """
-    if not job.fault_tolerance.enable:
-        return FTManager(None)
-
-    if not has_torchft:
-        raise ImportError("torchft is not installed. Please install it.")
-
-    if job.fault_tolerance.min_replica_size < 1:
-        raise ValueError("At least one FT replica is required.")
-
-    pg = ft.ProcessGroupNCCL()
+        if self.enabled:
+            return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
+        else:
+            return dp_degree, dp_rank

-    # If the training method is specific, then the quorum should be synchronous
-    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+    def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
+        if self.enabled and self.use_async_quorum:

-    return FTManager(
-        ft.Manager(
-            pg=pg,
-            min_replica_size=job.fault_tolerance.min_replica_size,
-            load_state_dict=None,
-            state_dict=None,
-            use_async_quorum=use_async_quorum,
-            replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
-        ),
-        group_size=job.fault_tolerance.group_size,
-        replica_id=job.fault_tolerance.replica_id,
-    )
-
-
-def ft_dist_reduce(
-    x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
-) -> tuple[torch.Tensor, str, DeviceMesh]:
-    if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh):
-        x = funcol.all_reduce(
-            x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
-        )
-        return x, reduceOp, mesh.managed_mesh.mesh
-    return x, reduceOp, mesh
+            def all_reduce_hook(output):
+                dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)

+            def apply_set_all_reduce_hook(m):
+                if isinstance(m, FSDPModule):
+                    m.set_all_reduce_hook(all_reduce_hook)

-def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
-    if has_torchft:
-        mesh = total_norm._spec.mesh
-        if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh):
-            # The gradients along the replicated dim has already been reduced.
-            # So we don't need another reducution beforing removing the
-            # replicate dimension
-            local_tensor = total_norm.to_local()
-            placements = list(copy.copy(total_norm._spec.placements))
-            placements.pop(mesh.replicate_dim)
-            return DTensor.from_local(local_tensor, mesh.mesh, placements)
+            for model_part in model_parts:
+                model_part.apply(apply_set_all_reduce_hook)

-    return total_norm
+    @property
+    def loss_sync_pg(
+        self,
+    ) -> Optional["ft.process_group.ManagedProcessGroup"]:
+        if self.enabled and self.use_async_quorum:
+            return self.replicate_pg
+        else:
+            # skip loss sync when using semi-sync training
+            return None


def maybe_semi_sync_training(
-    config: JobConfig,
+    ft_config: FTConfig,
    ft_manager: FTManager,
    model_parts: list[torch.nn.Module],
    optimizer: torch.optim.Optimizer,
) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
    """
    If TorchFT is enabled and the config is set, use semi_sync_method
    """
-    ft_config = config.fault_tolerance
    semi_sync_method = ft_config.semi_sync_method
-    torchft_enabled = ft_config.enable
-    if torchft_enabled and semi_sync_method is not None:
+    if ft_config.enable and semi_sync_method is not None:
        from torchft import local_sgd

        assert (
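
A minimal usage sketch of the refactored API follows. It is illustrative only and rests on assumptions not shown in this diff: that FaultTolerance() defaults to enable=False (so torchft need not be installed), that FTManager.enabled is False in that case, and that the disabled path of maybe_semi_sync_training yields a no-op context.

# Hypothetical wiring sketch; FTManager and maybe_semi_sync_training are the
# definitions from this diff, while the config defaults below are assumptions.
import torch

from torchtitan.config_manager import FaultTolerance as FTConfig

ft_config = FTConfig()             # assumed: fault tolerance disabled by default
ft_manager = FTManager(ft_config)  # the constructor now takes the FT config directly

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

print(ft_manager.get_dp_info(dp_degree=1, dp_rank=0))  # (1, 0) when FT is disabled
ft_manager.maybe_set_all_reduce_hook([model])          # no-op unless FT runs with an async quorum

# Wraps the step in DiLoCo/LocalSGD only when a semi_sync_method is configured;
# otherwise (assumed) it behaves as a no-op context manager.
with maybe_semi_sync_training(ft_config, ft_manager, [model], optimizer):
    pass  # training loop would go here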