Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")


fp_group = parser.add_mutually_exclusive_group()
Expand Down
15 changes: 10 additions & 5 deletions cuda_malloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib.util
from comfy.cli_args import args
import subprocess
import logging

#Can't use pytorch to get the GPU names because the cuda malloc backend has to be configured before pytorch is imported for the first time.
def get_gpu_names():
Expand Down Expand Up @@ -50,7 +51,7 @@ def enum_display_devices():
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
}

def cuda_malloc_supported():
def device_cuda_malloc_supported():
try:
names = get_gpu_names()
except:
Expand All @@ -62,8 +63,8 @@ def cuda_malloc_supported():
return False
return True


if not args.cuda_malloc:
def cuda_malloc_supported():
software_supported = False
try:
version = ""
torch_spec = importlib.util.find_spec("torch")
Expand All @@ -76,16 +77,20 @@ def cuda_malloc_supported():
version = module.__version__

if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
args.cuda_malloc = cuda_malloc_supported()
software_supported = True
except:
pass

return (software_supported and device_cuda_malloc_supported())

if args.cuda_malloc and not args.disable_cuda_malloc:
if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"

if not cuda_malloc_supported():
logging.warning("WARNING: this card most likely does not support cuda-malloc\n")

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
Loading