
WanVACETransformer3DModel with GGUF not working for 1.3B model #11878

Open
@nitinmukesh

Describe the bug

Support for GGUF in WanVACE was added in this PR:
#11807

This may work for the 14B model (not tested), but it does not work for the 1.3B model. I didn't post the issue earlier, but it's now confirmed that I'm not the only one running into it:
#11807 (comment)

Reproduction

import torch
from diffusers import AutoencoderKLWan, GGUFQuantizationConfig, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
# transformer_path = "https://huggingface.co/newgenai79/Wan-VACE-1.3B-diffusers-gguf/blob/main/Wan-VACE-1.3B-diffusers-Q8_0.gguf"
transformer_path = "https://huggingface.co/calcuis/wan-gguf/blob/main/wan2.1-v4-vace-1.3b-q4_0.gguf"

# Load the GGUF-quantized transformer from a single file
transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVACEPipeline.from_pretrained(
    model_id,
    transformer=transformer_gguf,
    vae=vae, 
    torch_dtype=torch.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()


prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    conditioning_scale=0.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output_GGUF1.mp4", fps=16)
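
If it helps with triage, here is a small diagnostic sketch (my own addition, not part of the original repro; the attribute path is read off the traceback below, so adjust it for your diffusers version). It prints the parameter dtype the time embedder exposes, which under GGUF quantization should be torch.uint8, i.e. the "Byte" in the error:

# Hypothetical diagnostic: what dtype do the time embedder's parameters
# report? GGUF keeps quantized weights as raw uint8 blocks and only
# dequantizes inside the linear layer's forward, so this is expected to
# print torch.uint8 for a quantized checkpoint.
te = transformer_gguf.condition_embedder.time_embedder
print(next(iter(te.parameters())).dtype)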

Logs

config.json: 100%|████████████████████████████████████████████████████████████████████████████| 662/662 [00:00<?, ?B/s]
config.json: 100%|████████████████████████████████████████████████████████████████████████████| 724/724 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors: 100%|██████████████████████████████████████████| 508M/508M [00:15<00:00, 33.7MB/s]
model_index.json: 100%|███████████████████████████████████████████████████████████████████████| 408/408 [00:00<?, ?B/s]
scheduler_config.json: 100%|██████████████████████████████████████████████████████████████████| 751/751 [00:00<?, ?B/s]
special_tokens_map.json: 7.08kB [00:00, 1.23MB/s]
config.json: 100%|████████████████████████████████████████████████████████████████████████████| 850/850 [00:00<?, ?B/s]
model.safetensors.index.json: 22.5kB [00:00, ?B/s]
tokenizer_config.json: 61.8kB [00:00, ?B/s]
spiece.model: 100%|███████████████████████████████████████████████████████████████| 4.55M/4.55M [00:00<00:00, 10.3MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████| 16.8M/16.8M [00:03<00:00, 4.45MB/s]
model-00003-of-00003.safetensors: 100%|███████████████████████████████████████████| 1.44G/1.44G [02:59<00:00, 8.03MB/s]
model-00002-of-00003.safetensors: 100%|███████████████████████████████████████████| 4.98G/4.98G [03:52<00:00, 21.4MB/s]
model-00001-of-00003.safetensors: 100%|███████████████████████████████████████████| 4.94G/4.94G [04:25<00:00, 18.6MB/s]
Fetching 13 files: 100%|███████████████████████████████████████████████████████████████| 13/13 [04:26<00:00, 20.49s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 146.18it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 5/5 [00:00<00:00, 11.13it/s]
  0%|                                                                                           | 0/30 [00:02<?, ?it/s]
Error: Python: Traceback (most recent call last):
  File ".\python\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\pipelines\wan\pipeline_wan_vace.py", line 909, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\accelerate\hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\transformers\transformer_wan_vace.py", line 324, in forward
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
                                                                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\transformers\transformer_wan.py", line 178, in forward
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\embeddings.py", line 1308, in forward
    sample = self.linear_1(sample)
             ^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\quantizers\gguf\utils.py", line 460, in forward
    output = torch.nn.functional.linear(inputs, weight, bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Byte and BFloat16
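
A possible reading of the traceback (my assumption, not a confirmed diagnosis): the error names mat1, the activation, as Byte, so it is the projected timestep that reaches the linear as raw uint8, while the GGUF weight has already been dequantized to BFloat16. If the installed transformer_wan.py follows the common diffusers pattern of casting the timestep to the time embedder's parameter dtype while exempting only torch.int8, then a GGUF layer, whose parameters are stored as torch.uint8, would drag the activation down to Byte. A standalone sketch of that failure mode:

import torch

# Sketch of the suspected failure mode (an assumption, not the verified
# diffusers code path): the projected timestep is cast to the time
# embedder's parameter dtype, and a guard that only exempts torch.int8
# lets the GGUF dtype torch.uint8 slip through.
timestep = torch.randn(1, 256)        # stands in for the sinusoidal projection
time_embedder_dtype = torch.uint8     # parameter dtype of a GGUF-quantized layer
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)  # activation becomes Byte

weight = torch.randn(1024, 256, dtype=torch.bfloat16)  # dequantized GGUF weight
torch.nn.functional.linear(timestep, weight)
# RuntimeError: mat1 and mat2 must have the same dtype, but got Byte and BFloat16

If that reading is right, it would also explain why some Wan GGUF files work: checkpoints that keep the small time-embedding tensors unquantized (F16/F32) never hit the bad cast.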

System Info

diffusers: latest build from source

Who can help?

@DN6
