[bug]: I don't think the bfloat16 config option is being used. #5799

@Vargol

Description

Is there an existing issue for this problem?

  • I have searched the existing issues

Operating system

macOS

GPU vendor

Apple Silicon (MPS)

GPU model

M3 base 10 GPU revision

GPU VRAM

24

Version number

v3.7.0

Browser

Safari 17.2.1

Python dependencies

{
"accelerate": "0.27.2",
"compel": "2.0.2",
"cuda": null,
"diffusers": "0.26.3",
"numpy": "1.26.4",
"opencv": "4.9.0.80",
"onnx": "1.15.0",
"pillow": "10.2.0",
"python": "3.10.13",
"torch": "2.3.0.dev20240221",
"torchvision": "0.18.0.dev20240221",
"transformers": "4.37.2",
"xformers": null
}

What happened

Recent PyTorch nightlies have added some bfloat16 support for MPS, and testing Diffusers with them showed there is enough support for Stable Diffusion to run, with a small but statistically significant decrease in seconds per iteration.
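
For reference, a minimal Diffusers script along these lines is enough to exercise bfloat16 on MPS (the model ID, prompt and step count here are just examples):

import torch
from diffusers import StableDiffusionPipeline

# bfloat16 on MPS needs a recent PyTorch nightly
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
)
pipe = pipe.to("mps")

image = pipe("a photo of an astronaut", num_inference_steps=20).images[0]
image.save("bf16_test.png")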

I wanted to see if I could use the nightlies with InvokeAI now that the basicSR dependency has been removed, and that testing worked fine. So I set the config to bfloat16.

Everything worked, but I saw no change in render times.

Digging through the code, I spotted a few bits that look like they force the use of float16, including one that I think prevents the use of bfloat16 on all platforms.

I made the following code changes

invokeai/backend/util/devices.py

Made choose_precision allow the use of bfloat16 for MPS:

def choose_precision(device: torch.device) -> str:
    """Returns an appropriate precision for the given torch device"""
    if device.type == "cuda":
        device_name = torch.cuda.get_device_name(device)
        if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
            if config.precision == "bfloat16":
                return "bfloat16"
            else:
                return "float16"
    elif device.type == "mps":
        if config.precision == "bfloat16":
            return "bfloat16"
        else:
            return "float16"
    return "float32"

This broke Compel, so I had to make the following quick fix. It needs a better solution, as I think the underlying problem is purely an MPS torch issue.

invokeai/app/invocations/compel.py

comment out both occurrences of dtype_for_device_getter, lines 125 and 248:

# dtype_for_device_getter=torch_dtype,
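
A better long-term fix might be to keep the argument but hand Compel a getter that respects the configured precision. Just a sketch; dtype_for_device is a hypothetical helper built on choose_precision above:

def dtype_for_device(device: torch.device) -> torch.dtype:
    # mirror choose_precision so Compel gets the same dtype
    # as the rest of the pipeline
    precision = choose_precision(device)
    return {"float32": torch.float32, "bfloat16": torch.bfloat16}.get(precision, torch.float16)

# then, at both call sites:
# dtype_for_device_getter=dtype_for_device,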

And finally, the one I think prevents bfloat16 usage on all platforms:

invokeai/app/services/model_manager/model_manager_default.py line 65

        dtype = torch.float32 if precision == "float32" else torch.float16

replaced with

        if precision == "float32":
            dtype = torch.float32
        elif precision == "bfloat16":
            dtype = torch.bfloat16
        else:
            dtype = torch.float16

and that seems to do the trick; I see the small speed-up I'd expect.
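
If a one-liner is preferred there, a dict lookup would keep it compact; purely a stylistic alternative:

        dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16}.get(precision, torch.float16)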

I suspect there are other places where similar changes need to be made, for the same reason:

invokeai/app/invocations/latent.py
loads the VAE as either float16 or float32, never bfloat16. Not sure if that cast is actually necessary.
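
If it is, a bfloat16-aware version of the cast might look roughly like this (a sketch; the fp32 flag and surrounding names are my guesses at the existing code):

# pick the VAE dtype from the configured precision instead of
# hard-coding float16 as the only alternative to float32
if self.fp32:
    vae.to(dtype=torch.float32)
elif config.precision == "bfloat16":
    vae.to(dtype=torch.bfloat16)
else:
    vae.to(dtype=torch.float16)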

Model installation
invokeai/app/services/model_install/model_install_default.py

_guess_variant() looks like it will fetch the 32-bit version of models for bfloat16 instead of the fp16 variant by default (and I've noticed one or two bf16 variants appearing recently); a possible tweak is sketched below.
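
Sketched on the assumption that _guess_variant() keys off the precision string and returns a ModelRepoVariant (names taken from memory of the surrounding code):

def _guess_variant(self) -> Optional[ModelRepoVariant]:
    """Guess the repo variant to download from the configured precision."""
    precision = choose_precision(choose_torch_device())
    # treat bfloat16 like float16 here: downloading the fp16 variant and
    # casting up beats pulling full fp32 weights
    return ModelRepoVariant.FP16 if precision in ("float16", "bfloat16") else None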

Model loading
invokeai/backend/model_management/models/base.py: I think line 285 means that for bfloat16 it will load the fp32 model in preference to the fp16 variant when both are installed.
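
If that's right, the fix is probably just widening the dtype check; sketched, with the variable names assumed:

# prefer the fp16 variant for both half-precision dtypes
if torch_dtype in (torch.float16, torch.bfloat16):
    variants = ["fp16", None]
else:
    variants = [None, "fp16"]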

What you expected to happen

Modifying the config file to use bfloat16 should be enough to use bfloat16 precision.
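
i.e. this one line in the config file should be sufficient:

precision: bfloat16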

How to reproduce the problem

Change the invoke.yaml file to use bfloat16. You may need some debug code to confirm it's actually still using float16, and a Diffusers script to test exactly how bfloat16 behaves on your hardware.
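
For the debug side, printing the dtype of any loaded model parameter is enough; e.g., anywhere the UNet is in scope:

# before the fixes above this prints torch.float16
# even with precision set to bfloat16
print(next(unet.parameters()).dtype)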

Additional context

No response

Discord username

No response
