Skip to content

[Wan 2.2 VAE] fix VAE tiling encode/decode #12191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

miaojinc
Copy link
Contributor

What does this PR do?

Current AutoencoderKLWan lacks some patchify stuff when tiling.
Also add patch_size config for Wan VAE unit tests.

Without this, we will got error likes:

  File "/home/mjc/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 645, in __call__
    video = self.vae.decode(latents, return_dict=False)[0]
  File "/home/mjc/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/home/mjc/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1248, in decode
    decoded = self._decode(z).sample
  File "/home/mjc/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1204, in _decode
    return self.tiled_decode(z, return_dict=return_dict)
  File "/home/mjc/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1374, in tiled_decode
    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
  File "/root/miniforge3/envs/wan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/wan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mjc/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 893, in forward
    x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
  File "/root/miniforge3/envs/wan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/wan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mjc/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 709, in forward
    x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2

Before submitting

Who can review?

hi @yiyixuxu @a-r-r-o-w
Could you please help to review it, thanks

Current AutoencoderKLWan lacks some patchify stuff when tiling.
Also add patch_size config for Wan VAE unit tests.

Signed-off-by: Jincheng Miao <[email protected]>
Apply patchify/unpatchify if needed.

Signed-off-by: Jincheng Miao <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant