48 changes: 35 additions & 13 deletions src/weathergen/train/trainer.py
@@ -216,12 +216,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
 
         if cf.with_ddp and cf.with_fsdp:
             fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=self.mixed_precision_dtype,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=self.mixed_precision_dtype,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             modules_to_shard = (
                 MLP,
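
Note: this hunk only adds explicit parentheses around the conditional expression (ruff/black-style); behavior is unchanged, and `mp_policy` is still a `MixedPrecisionPolicy` only when `cf.with_mixed_precision` is set. For readers unfamiliar with FSDP2, here is a minimal sketch of the pattern being reformatted. It assumes PyTorch >= 2.6 (public `fully_shard`/`MixedPrecisionPolicy` under `torch.distributed.fsdp`) and an already-initialized process group (e.g. launched via `torchrun`); `shard_model`, `with_mixed_precision`, and the `nn.Linear` stand-in are illustrative names, not from this repo:

```python
# Sketch, not the repo's implementation: optional mixed-precision policy
# plus selective sharding of chosen submodule types, as in the hunk above.
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_model(model: nn.Module, with_mixed_precision: bool) -> None:
    mp_policy = (
        MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,  # compute/communicate params in bf16
            reduce_dtype=torch.float32,  # reduce gradients in fp32 for stability
        )
        if with_mixed_precision
        else None
    )
    # Only pass the policy when set, so FSDP2 defaults apply otherwise.
    fsdp_kwargs = {"mp_policy": mp_policy} if mp_policy is not None else {}
    modules_to_shard = (nn.Linear,)  # stand-in for the repo's (MLP, ...) tuple
    for module in model.modules():
        if isinstance(module, modules_to_shard):
            fully_shard(module, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)  # wrap the root module last
```

The trainer passes `"mp_policy": None` straight through; the sketch drops the key instead, which is the more defensive choice across PyTorch versions.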
@@ -252,12 +254,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                     fully_shard(module, **fsdp_kwargs)
 
             full_precision_fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=torch.float32,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=torch.float32,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             for module in self.model.pred_adapter_kv.modules():
                 if isinstance(module, modules_to_shard):
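
As above, a formatting-only change. Semantically, this second dict pins `pred_adapter_kv` to full precision even when mixed precision is enabled globally, which works because each `fully_shard` call carries its own policy. A hedged sketch of that per-submodule override (same PyTorch >= 2.6 and initialized-process-group assumptions; `Net`, `backbone`, and `adapter` are illustrative names):

```python
# Sketch: keep one precision-sensitive submodule in fp32 while the rest
# of the model trains in bf16. Each fully_shard call takes its own policy.
# Assumes torch.distributed is already initialized (e.g. via torchrun).
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)
        self.adapter = nn.Linear(32, 32)


model = Net()
bf16 = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fp32 = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)

fully_shard(model.adapter, mp_policy=fp32)   # excluded from mixed precision
fully_shard(model.backbone, mp_policy=bf16)  # sharded with bf16 compute
fully_shard(model, mp_policy=bf16)           # root wrap
```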
@@ -274,7 +278,6 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
         for tensor in itertools.chain(self.model.parameters(), self.model.buffers()):
             assert tensor.device == torch.device("meta")
 
-        # load model if specified
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
             self.model.reset_parameters()
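
The deleted comment was stale: this branch is the fresh-start path (no `run_id_contd`), where the meta-constructed model is materialized and freshly initialized; checkpoint loading happens elsewhere. A self-contained sketch of the meta-device pattern the surrounding lines rely on (`TinyNet` is illustrative, not a repo class):

```python
# Sketch: meta-device construction followed by materialization.
# to_empty() allocates uninitialized storage on the target device; values
# must then be set explicitly, e.g. via each submodule's reset_parameters().
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    def reset_parameters(self):
        # nn.Linear already defines reset_parameters(); delegate to it.
        self.proj.reset_parameters()


with torch.device("meta"):
    model = TinyNet()  # no real memory allocated yet

assert all(p.device == torch.device("meta") for p in model.parameters())

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=device)  # allocate storage (contents are garbage)
model.reset_parameters()       # now fill in real initial values
```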
@@ -714,6 +717,25 @@ def load_model(self, run_id: str, epoch=-1):
         # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor
         mkeys, ukeys = self.model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)
 
+        if mkeys:
+            # Get the unique parent modules of the missing parameters
+            new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}
+
+            # Keep only the highest-level "root" new modules to avoid redundant initialization
+            root_new_modules = set()
+            for path in sorted(new_modules_to_init):
+                if not any(path.startswith(root + ".") for root in root_new_modules):
+                    root_new_modules.add(path)
+
+            # Build a name -> module lookup and initialize the new modules
+            all_modules = dict(self.model.named_modules())
+            for path in root_new_modules:
+                if is_root():
+                    logger.info(f"Initializing new module not found in checkpoint: {path}")
+                module_to_init = all_modules[path]
+                module_to_init.to_empty(device="cuda")
+                module_to_init.reset_parameters()
+
         if not is_model_sharded:
             self.model = self.model.to(self.device)

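With `assign=True`, parameters absent from the checkpoint are left on the meta device, so the new block must materialize and initialize them explicitly (it also assumes each such module defines `reset_parameters()`). The reduction of missing keys to their highest-level parent modules is pure string logic and easy to check in isolation; a self-contained sketch (the function name and the `.proj`/`.norm` paths are ours, `pred_adapter_kv` is from the diff):

```python
# Sketch: reduce missing-parameter paths to their highest-level parent
# modules, so each new module is initialized exactly once.
def find_root_new_modules(missing_keys: set[str]) -> set[str]:
    new_modules = {key.rsplit(".", 1)[0] for key in missing_keys}
    roots: set[str] = set()
    # Sorting guarantees a parent path ("a.b") is visited before its
    # children ("a.b.c"), so children are skipped once a root is recorded.
    for path in sorted(new_modules):
        if not any(path.startswith(root + ".") for root in roots):
            roots.add(path)
    return roots


missing = {
    "pred_adapter_kv.proj.weight",       # parent: pred_adapter_kv.proj
    "pred_adapter_kv.proj.bias",         # parent: pred_adapter_kv.proj
    "pred_adapter_kv.proj.norm.weight",  # parent is a child of the above
}
assert find_root_new_modules(missing) == {"pred_adapter_kv.proj"}
```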