Commit f414753

Support for WAN 2.2
1 parent c9229c3 commit f414753

File tree

10 files changed: +1524 −51 lines


README.md

Lines changed: 9 additions & 1 deletion
@@ -481,7 +481,15 @@ To generate images, run the following command:
 
 ```bash
 HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
-LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py --model_name=wan2.1 attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
+```
+## Wan2.2
+
+Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py --model_name=wan2.2 attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
 ```
 
 ## Flux
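
The only difference between the two new invocations is `--model_name` (`wan2.1` vs. `wan2.2`); the flash-attention tuning is shared. The `flash_block_sizes` value is a JSON object, so it parses into a plain dict; a minimal sketch assuming ordinary `json.loads` semantics (the real maxdiffusion config parser may validate further):

```python
import json

# The flag value exactly as passed in the shell commands above
# (whitespace inside the JSON is insignificant).
raw = ('{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, '
       '"block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, '
       '"block_q_dq" : 3024, "block_kv_dq" : 2048 }')
flash_block_sizes = json.loads(raw)
assert flash_block_sizes["block_q"] == 3024  # query tile size for the flash kernel
```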

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager(
   if checkpoint_type == FLUX_CHECKPOINT:
     item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
   elif checkpoint_type == WAN_CHECKPOINT:
-    item_names = ("wan_state", "wan_config")
+    item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
   else:
     item_names = (
         "unet_config",
Lines changed: 207 additions & 0 deletions

@@ -0,0 +1,207 @@ (new file)

```python
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
import json

import jax
import numpy as np
from typing import Optional, Tuple
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
from .. import max_logging, max_utils
import orbax.checkpoint as ocp
from etils import epath

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
    if step is None:
      max_logging.log("No WAN checkpoint found.")
      return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)

    low_noise_transformer_metadata = metadatas.low_noise_transformer_state
    abstract_tree_structure_low_params = jax.tree_util.tree_map(
        ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
    )
    low_params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_low_params,
        )
    )

    high_noise_transformer_metadata = metadatas.high_noise_transformer_state
    abstract_tree_structure_high_params = jax.tree_util.tree_map(
        ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
    )
    high_params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_high_params,
        )
    )

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            low_noise_transformer_state=low_params_restore,
            high_noise_transformer_state=high_params_restore,
            wan_config=ocp.args.JsonRestore(),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
      # Check for optimizer state in either transformer
      if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
        opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
      elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
        opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
    }

    items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
    items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")

  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      """Only saves the config fields that are needed and can be serialized to JSON."""
      if not hasattr(model_or_config, "config"):
        return None
      source_config = dict(model_or_config.config)

      # 1. config keys that can be serialized to JSON
      SAFE_KEYS = [
          "_class_name",
          "_diffusers_version",
          "model_type",
          "patch_size",
          "num_attention_heads",
          "attention_head_dim",
          "in_channels",
          "out_channels",
          "text_dim",
          "freq_dim",
          "ffn_dim",
          "num_layers",
          "cross_attn_norm",
          "qk_norm",
          "eps",
          "image_dim",
          "added_kv_proj_dim",
          "rope_max_seq_len",
          "pos_embed_seq_len",
          "flash_min_seq_length",
          "flash_block_sizes",
          "attention",
          "_use_default_values",
      ]

      # 2. keep only the keys that are in the SAFE_KEYS list
      clean_config = {}
      for key in SAFE_KEYS:
        if key in source_config:
          clean_config[key] = source_config[key]

      # 3. deal with special data types and precision
      if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
        clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

      if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
        clean_config["weights_dtype"] = source_config["weights_dtype"].name

      if "precision" in source_config and hasattr(source_config["precision"], "name"):
        clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

      return clean_config

    items_to_save = {
        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)

    # Create CompositeArgs for Orbax
    save_args = ocp.args.Composite(**items_to_save)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=save_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")
```

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ save_config_to_gcs: False
 log_period: 100
 
 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
+model_name: wan2.1
 
 # Overrides the transformer from pretrained_model_name_or_path
 wan_transformer_pretrained_model_name_or_path: ''
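
The new `model_name` key defaults to `wan2.1`, so existing runs are unchanged; the README commands above override it to `wan2.2` from the command line. A rough sketch of that layering, assuming simple `key=value` overrides on top of the YAML (maxdiffusion's real config loader differs in detail):

```python
import yaml

def load_config(yml_path, argv):
  """Toy loader: YAML defaults first, then key=value CLI overrides win."""
  with open(yml_path) as f:
    cfg = yaml.safe_load(f)  # e.g. {'model_name': 'wan2.1', ...}
  for arg in argv:
    if "=" in arg:
      key, value = arg.lstrip("-").split("=", 1)
      cfg[key] = value
  return cfg

# load_config("src/maxdiffusion/configs/base_wan_14b.yml", ["--model_name=wan2.2"])
# -> cfg["model_name"] == "wan2.2"
```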
