Skip to content

Commit 88611cf

Browse files
authored
Bug Fix pipeline creation from diffusers (#276)
* Add ability to save optimizer and resume while training * Bug fix, case where not checkpoint load from diffusers
1 parent 662d501 commit 88611cf

File tree

3 files changed

+128
-4
lines changed

3 files changed

+128
-4
lines changed

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
5959
- name: PyTest
6060
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
61-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
61+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
6262
# add_pull_ready:
6363
# if: github.ref != 'refs/heads/main'
6464
# permissions:

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import jax
2121
import numpy as np
22+
from typing import Optional, Tuple
2223
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2324
from ..pipelines.wan.wan_pipeline import WanPipeline
2425
from .. import max_logging, max_utils
@@ -50,12 +51,13 @@ def _create_optimizer(self, model, config, learning_rate):
5051
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
5152
return tx, learning_rate_scheduler
5253

53-
def load_wan_configs_from_orbax(self, step):
54+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
5455
if step is None:
5556
step = self.checkpoint_manager.latest_step()
5657
max_logging.log(f"Latest WAN checkpoint step: {step}")
5758
if step is None:
58-
return None
59+
max_logging.log("No WAN checkpoint found.")
60+
return None, None
5961
max_logging.log(f"Loading WAN checkpoint from step {step}")
6062
metadatas = self.checkpoint_manager.item_metadata(step)
6163
transformer_metadata = metadatas.wan_state
@@ -86,7 +88,7 @@ def load_diffusers_checkpoint(self):
8688
pipeline = WanPipeline.from_pretrained(self.config)
8789
return pipeline
8890

89-
def load_checkpoint(self, step=None):
91+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
9092
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
9193
opt_state = None
9294
if restored_checkpoint:
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import unittest
15+
from unittest.mock import patch, MagicMock
16+
17+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
18+
19+
class WanCheckpointerTest(unittest.TestCase):
20+
def setUp(self):
21+
self.config = MagicMock()
22+
self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
23+
self.config.dataset_type = "test_dataset"
24+
25+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
26+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
27+
def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
28+
mock_manager = MagicMock()
29+
mock_manager.latest_step.return_value = None
30+
mock_create_manager.return_value = mock_manager
31+
32+
mock_pipeline_instance = MagicMock()
33+
mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
34+
35+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
36+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
37+
38+
mock_manager.latest_step.assert_called_once()
39+
mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
40+
self.assertEqual(pipeline, mock_pipeline_instance)
41+
self.assertIsNone(opt_state)
42+
self.assertIsNone(step)
43+
44+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
45+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
46+
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
47+
mock_manager = MagicMock()
48+
mock_manager.latest_step.return_value = 1
49+
metadata_mock = MagicMock()
50+
metadata_mock.wan_state = {}
51+
mock_manager.item_metadata.return_value = metadata_mock
52+
53+
restored_mock = MagicMock()
54+
restored_mock.wan_state = {'params': {}}
55+
restored_mock.wan_config = {}
56+
restored_mock.keys.return_value = ['wan_state', 'wan_config']
57+
def getitem_side_effect(key):
58+
if key == 'wan_state':
59+
return restored_mock.wan_state
60+
raise KeyError(key)
61+
restored_mock.__getitem__.side_effect = getitem_side_effect
62+
mock_manager.restore.return_value = restored_mock
63+
64+
mock_create_manager.return_value = mock_manager
65+
66+
mock_pipeline_instance = MagicMock()
67+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
68+
69+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
70+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
71+
72+
mock_manager.restore.assert_called_once_with(
73+
directory=unittest.mock.ANY,
74+
step=1,
75+
args=unittest.mock.ANY
76+
)
77+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
78+
self.assertEqual(pipeline, mock_pipeline_instance)
79+
self.assertIsNone(opt_state)
80+
self.assertEqual(step, 1)
81+
82+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
83+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
84+
def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
85+
mock_manager = MagicMock()
86+
mock_manager.latest_step.return_value = 1
87+
metadata_mock = MagicMock()
88+
metadata_mock.wan_state = {}
89+
mock_manager.item_metadata.return_value = metadata_mock
90+
91+
restored_mock = MagicMock()
92+
restored_mock.wan_state = {'params': {}, 'opt_state': {'learning_rate': 0.001}}
93+
restored_mock.wan_config = {}
94+
restored_mock.keys.return_value = ['wan_state', 'wan_config']
95+
def getitem_side_effect(key):
96+
if key == 'wan_state':
97+
return restored_mock.wan_state
98+
raise KeyError(key)
99+
restored_mock.__getitem__.side_effect = getitem_side_effect
100+
mock_manager.restore.return_value = restored_mock
101+
102+
mock_create_manager.return_value = mock_manager
103+
104+
mock_pipeline_instance = MagicMock()
105+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
106+
107+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
108+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
109+
110+
mock_manager.restore.assert_called_once_with(
111+
directory=unittest.mock.ANY,
112+
step=1,
113+
args=unittest.mock.ANY
114+
)
115+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
116+
self.assertEqual(pipeline, mock_pipeline_instance)
117+
self.assertIsNotNone(opt_state)
118+
self.assertEqual(opt_state['learning_rate'], 0.001)
119+
self.assertEqual(step, 1)
120+
121+
if __name__ == "__main__":
122+
unittest.main()

0 commit comments

Comments
 (0)