7 changes: 6 additions & 1 deletion comfy/clip_vision.py
@@ -29,7 +29,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
     else:
         scale_size = (size, size)
 
-    image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+    if image.device.type == 'musa':
+        image = image.cpu()
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = image.to('musa')
+    else:
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
     h = (image.shape[2] - size)//2
     w = (image.shape[3] - size)//2
     image = image[:,:,h:h+size,w:w+size]
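Note: the new branch round-trips through the CPU, which suggests the MUSA backend lacks a native antialiased-bicubic kernel. A self-contained sketch of the same fallback (interpolate_bicubic is a hypothetical name, not part of the patch; resizing back with image.device rather than a hard-coded 'musa' would also preserve the device index):

import torch

def interpolate_bicubic(image: torch.Tensor, size) -> torch.Tensor:
    # Hypothetical helper distilled from the patch above: antialiased bicubic
    # resize, falling back to the CPU for device types without the kernel.
    if image.device.type == 'musa':
        out = torch.nn.functional.interpolate(image.cpu(), size=size, mode="bicubic", antialias=True)
        return out.to(image.device)  # preserves the original device index
    return torch.nn.functional.interpolate(image, size=size, mode="bicubic", antialias=True)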
11 changes: 8 additions & 3 deletions comfy/ldm/flux/math.py
@@ -27,9 +27,14 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
         device = torch.device("cpu")
     else:
         device = pos.device
 
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
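Note: the MUSA branch sidesteps the float64 power theta**scale, presumably unsupported on that backend, using the identity 1/theta**scale = exp(-scale * log(theta)) computed in float32; the 1e-6 offset guards against log(0). A quick CPU-runnable check that the two formulations agree to float32 precision:

import torch

dim, theta = 128, 10000.0
scale64 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64)

# Reference formulation (float64 power), as in the non-MUSA branch.
omega_ref = 1.0 / (theta ** scale64)

# MUSA-friendly formulation: exp(-scale * log(theta)) in float32.
theta32 = torch.tensor(theta, dtype=torch.float32)
omega_musa = torch.exp(-scale64.to(torch.float32) * torch.log(theta32 + 1e-6))

# Any difference stems from float32 rounding, not from the algebraic rewrite.
print(torch.max((omega_ref.to(torch.float32) - omega_musa).abs()).item())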
37 changes: 37 additions & 0 deletions comfy/model_management.py
@@ -131,6 +131,15 @@ def get_supported_float8_types():
 except:
     mlu_available = False
 
+try:
+    import torch_musa
+    _ = torch_musa.device_count()
+    musa_available = torch_musa.is_available()
+    if musa_available:
+        logging.info("MUSA device detected: {}".format(torch_musa.get_device_name(0)))
+except:
+    musa_available = False
+
 try:
     ixuca_available = hasattr(torch, "corex")
 except:
@@ -159,6 +168,12 @@ def is_mlu():
         return True
     return False
 
+def is_musa():
+    global musa_available
+    if musa_available:
+        return True
+    return False
+
 def is_ixuca():
     global ixuca_available
     if ixuca_available:
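Note: detection follows the same pattern as the existing MLU and Ascend backends: importing the out-of-tree torch_musa package registers the 'musa' device type with torch, and a guarded probe sets the availability flag. A stand-alone sketch of the pattern, assuming torch_musa mirrors the torch.cuda API as this patch does:

import logging
import torch

try:
    import torch_musa  # side effect: registers the 'musa' device type and torch.musa
    musa_available = torch_musa.is_available()
except Exception:  # ImportError without the package, RuntimeError on driver issues
    musa_available = False

if musa_available:
    logging.info("MUSA device detected: {}".format(torch_musa.get_device_name(0)))
    device = torch.device("musa", torch.musa.current_device())
else:
    device = torch.device("cpu")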
@@ -182,6 +197,8 @@ def get_torch_device():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
elif is_musa():
return torch.device('musa', torch.musa.current_device())
else:
return torch.device(torch.cuda.current_device())

@@ -215,6 +232,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_musa():
+            stats = torch.musa.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total = torch.musa.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -1157,6 +1180,14 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_musa():
+            stats = torch.musa.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_musa, _ = torch.musa.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_musa + mem_free_torch
+
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
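Note: both memory hunks mirror the existing MLU branches, where reserved-but-inactive allocator memory counts as reusable. A sketch of the free-memory arithmetic, assuming torch_musa is installed and exposes memory_stats and mem_get_info with the torch.cuda semantics the patch relies on:

import torch
import torch_musa  # assumed installed; registers torch.musa

dev = torch.device("musa", 0)
stats = torch.musa.memory_stats(dev)
mem_active = stats['active_bytes.all.current']      # bytes held by live tensors
mem_reserved = stats['reserved_bytes.all.current']  # bytes held by the caching allocator
mem_free_musa, _ = torch.musa.mem_get_info(dev)     # driver-level free bytes

# Reserved-but-inactive memory can be reused by torch without new driver
# allocations, so it counts toward the usable total.
mem_free_torch = mem_reserved - mem_active
print("usable bytes:", mem_free_musa + mem_free_torch)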
@@ -1235,6 +1266,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_mlu():
         return True
 
+    if is_musa():
+        return True
+
     if is_ixuca():
         return True

@@ -1301,6 +1335,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_musa():
+        return True
+
     if is_ixuca():
         return True
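Note: these two hunks make should_use_fp16 and should_use_bf16 return True unconditionally on MUSA, i.e. every Moore Threads GPU is assumed to handle both half-precision formats. A condensed sketch of that gate (preferred_low_precision is a hypothetical helper, not part of the patch):

import torch

def preferred_low_precision(device: torch.device):
    # Mirrors the early returns added above: any 'musa' device is treated
    # as natively supporting both fp16 and bf16.
    if device.type == "musa":
        return (torch.float16, torch.bfloat16)
    return ()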
