
Commit c0a62dc

remove pretrained checkpoint loading.
1 parent c9229c3 commit c0a62dc

2 files changed: +11 -42 lines changed


src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 11 additions & 41 deletions
@@ -116,45 +116,13 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
-  p_model_factory = partial(create_model, wan_config=wan_config)
-  wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
-  graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-
-  # 3. retrieve the state shardings, mapping logical names to mesh axis names.
-  logical_state_spec = nnx.get_partition_spec(state)
-  logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
-  logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
-  params = state.to_pure_dict()
-  state = dict(nnx.to_flat_state(state))
+  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    p_model_factory = partial(create_model, wan_config=wan_config)
+    wan_transformer = p_model_factory(rngs=rngs)
 
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without fist copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  if restored_checkpoint:
-    if "params" in restored_checkpoint["wan_state"]:  # if checkpointed with optimizer
-      params = restored_checkpoint["wan_state"]["params"]
-    else:  # if not checkpointed with optimizer
-      params = restored_checkpoint["wan_state"]
-  else:
-    params = load_wan_transformer(
-        config.wan_transformer_pretrained_model_name_or_path,
-        params,
-        "cpu",
-        num_layers=wan_config["num_layers"],
-        scan_layers=config.scan_layers,
-    )
-
-  params = jax.tree_util.tree_map_with_path(
-      lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params
-  )
-  for path, val in flax.traverse_util.flatten_dict(params).items():
-    if restored_checkpoint:
-      path = path[:-1]
-    sharding = logical_state_sharding[path].value
-    state[path].value = device_put_replicated(val, sharding)
-  state = nnx.from_flat_state(state)
-
-  wan_transformer = nnx.merge(graphdef, state, rest_of_state)
   return wan_transformer
 
 
@@ -392,9 +360,10 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
     tokenizer = cls.load_tokenizer(config=config)
 
     scheduler, scheduler_state = cls.load_scheduler(config=config)
-
-    with mesh:
-      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    wan_vae = None
+    vae_cache = None
+    # with mesh:
+    #   wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
 
     return WanPipeline(
         tokenizer=tokenizer,
@@ -429,9 +398,10 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
     tokenizer = cls.load_tokenizer(config=config)
 
     scheduler, scheduler_state = cls.load_scheduler(config=config)
-
-    with mesh:
-      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    wan_vae = None
+    vae_cache = None
+    # with mesh:
+    #   wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
 
     pipeline = WanPipeline(
         tokenizer=tokenizer,
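Both from_checkpoint and from_pretrained now construct the pipeline with wan_vae and vae_cache set to None, so downstream code that decodes latents has to handle the missing VAE. A hypothetical guard, not part of this commit; it only assumes the pipeline exposes wan_vae and vae_cache attributes matching the values passed to the constructor above:

# Hypothetical helper, not from this commit: fail fast if the VAE was skipped
# at construction time (load_vae is commented out in this change).
def require_vae(pipeline):
  if pipeline.wan_vae is None or pipeline.vae_cache is None:
    raise ValueError(
        "This WanPipeline was constructed without a VAE; "
        "re-enable load_vae before decoding latents."
    )
  return pipeline.wan_vae, pipeline.vae_cache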

src/maxdiffusion/train_wan.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.config
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
-  flax.config.update("flax_always_shard_variable", False)
   train(config)
 
 
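If the old behavior is still needed locally, the deleted override can be re-applied ahead of train(config). A minimal sketch, assuming only that flax.config.update accepts the flag name exactly as the removed line used it; the flag's semantics are not described in this diff:

# Sketch only: restores the override that this commit deletes from main().
# "flax_always_shard_variable" is taken verbatim from the removed line.
import flax

flax.config.update("flax_always_shard_variable", False)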
