File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -116,15 +116,19 @@ def maybe_semi_sync_training(
116
116
# Create the outer optimizer based on the inner optimizer parameters.
117
117
params = [group ["params" ] for group in optimizer .param_groups ]
118
118
params = [param for sublist in params for param in sublist ]
119
- outer_optimizer = torch .optim .SGD (
120
- params , lr = 0.7 , momentum = 0.9 , nesterov = True
121
- )
119
+ outer_optimizers = []
120
+ for model in model_parts :
121
+ params = [p for p in model .parameters () if p .requires_grad ]
122
+ outer_optimizer = torch .optim .SGD (
123
+ params , lr = 0.7 , momentum = 0.9 , nesterov = True
124
+ )
125
+ outer_optimizers .append (outer_optimizer )
122
126
123
127
return local_sgd .DiLoCo (
124
128
manager = ft_manager ._manager ,
125
129
model_fragments = model_parts ,
126
130
inner_optimizer = optimizer ,
127
- outer_optimizer = outer_optimizer ,
131
+ outer_optimizer = outer_optimizers ,
128
132
sync_every = ft_config .sync_steps ,
129
133
should_quantize = ft_config .should_quantize ,
130
134
fragment_sync_delay = ft_config .fragment_sync_delay ,
You can’t perform that action at this time.
0 commit comments