We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3aa09b9 commit 327a99cCopy full SHA for 327a99c
torchtitan/experiments/forge/engine.py
@@ -215,7 +215,11 @@ def __init__(self, job_config: ForgeJobConfig):
215
lr_schedulers=self.lr_schedulers,
216
states={"train_state": self},
217
checkpoint_config=job_config.checkpoint,
218
- sd_adapter=self.train_spec.state_dict_adapter,
+ sd_adapter=(
219
+ self.train_spec.state_dict_adapter(model_args)
220
+ if self.train_spec.state_dict_adapter
221
+ else None
222
+ ),
223
)
224
225
loss_parallel_enabled = (
0 commit comments