@@ -116,45 +116,13 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):

   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
-  p_model_factory = partial(create_model, wan_config=wan_config)
-  wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
-  graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-
-  # 3. retrieve the state shardings, mapping logical names to mesh axis names.
-  logical_state_spec = nnx.get_partition_spec(state)
-  logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
-  logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
-  params = state.to_pure_dict()
-  state = dict(nnx.to_flat_state(state))
+  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    p_model_factory = partial(create_model, wan_config=wan_config)
+    wan_transformer = p_model_factory(rngs=rngs)

   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  if restored_checkpoint:
-    if "params" in restored_checkpoint["wan_state"]:  # if checkpointed with optimizer
-      params = restored_checkpoint["wan_state"]["params"]
-    else:  # if not checkpointed with optimizer
-      params = restored_checkpoint["wan_state"]
-  else:
-    params = load_wan_transformer(
-        config.wan_transformer_pretrained_model_name_or_path,
-        params,
-        "cpu",
-        num_layers=wan_config["num_layers"],
-        scan_layers=config.scan_layers,
-    )
-
-  params = jax.tree_util.tree_map_with_path(
-      lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params
-  )
-  for path, val in flax.traverse_util.flatten_dict(params).items():
-    if restored_checkpoint:
-      path = path[:-1]
-    sharding = logical_state_sharding[path].value
-    state[path].value = device_put_replicated(val, sharding)
-  state = nnx.from_flat_state(state)
-
-  wan_transformer = nnx.merge(graphdef, state, rest_of_state)
   return wan_transformer


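The block removed above followed a common JAX recipe: build the model abstractly with `nnx.eval_shape` so no HBM is touched, map each parameter's logical axes to mesh shardings, then put the host-loaded weights straight into their sharded layout on device. The new code instead constructs the module eagerly under the mesh and `nn_partitioning.axis_rules` context. Below is a minimal, self-contained sketch of the removed recipe using plain `jax.eval_shape` and `jax.device_put`; the toy parameter tree, 1-D mesh, and function names are illustrative only, not the actual WAN transformer or its helpers.

```python
# Minimal sketch of "abstract init + sharded placement"; names are illustrative.
from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_params(hidden: int):
  # Stand-in for the transformer factory; the real code builds an nnx.Module.
  k1, k2 = jax.random.split(jax.random.key(0))
  return {
      "w_in": jax.random.normal(k1, (hidden, 4 * hidden)),
      "w_out": jax.random.normal(k2, (4 * hidden, hidden)),
  }


# 1. eval_shape returns ShapeDtypeStructs only -- no FLOPs, no HBM allocated.
abstract = jax.eval_shape(partial(init_params, 1024))

# 2. Map each (abstract) parameter to a sharding on the mesh.
mesh = Mesh(np.array(jax.devices()), ("data",))
shardings = {
    "w_in": NamedSharding(mesh, P("data", None)),
    "w_out": NamedSharding(mesh, P(None, "data")),
}

# 3. "Load" weights on host (here: zeros with the abstract shapes), then
#    device_put each one directly into its sharded layout, so no single
#    device ever has to hold the full parameter set.
host_params = jax.tree_util.tree_map(lambda a: np.zeros(a.shape, a.dtype), abstract)
params = jax.tree_util.tree_map(jax.device_put, host_params, shardings)

print(jax.tree_util.tree_map(lambda x: x.sharding, params))
```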
@@ -392,9 +360,10 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
     tokenizer = cls.load_tokenizer(config=config)

     scheduler, scheduler_state = cls.load_scheduler(config=config)
-
-    with mesh:
-      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    wan_vae = None
+    vae_cache = None
+    # with mesh:
+    #   wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)

     return WanPipeline(
         tokenizer=tokenizer,
@@ -429,9 +398,10 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
     tokenizer = cls.load_tokenizer(config=config)

     scheduler, scheduler_state = cls.load_scheduler(config=config)
-
-    with mesh:
-      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    wan_vae = None
+    vae_cache = None
+    # with mesh:
+    #   wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)

     pipeline = WanPipeline(
         tokenizer=tokenizer,
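Both constructors now hand `None` for `wan_vae` and `vae_cache` to `WanPipeline`, so any downstream step that decodes latents presumably needs to check for the missing VAE before using it. A hypothetical guard is sketched below; the attribute names simply mirror the constructor arguments above and the error message is illustrative.

```python
# Hypothetical guard; assumes the pipeline keeps the constructor args as attributes.
def require_vae(pipeline):
  if pipeline.wan_vae is None or pipeline.vae_cache is None:
    raise ValueError(
        "WanPipeline was built without a VAE (wan_vae/vae_cache are None); "
        "load the VAE before decoding latents."
    )
  return pipeline.wan_vae, pipeline.vae_cache
```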