From 1c7fcdd0898c7723422401a603b4ee43a557c66a Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 11 Aug 2025 10:18:50 +0200 Subject: [PATCH] Make cudaMalloc backend opt-in --- comfy/cli_args.py | 1 - cuda_malloc.py | 15 ++++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0d760d524d07..aaac81b1374c 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -53,7 +53,6 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") -cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") fp_group = parser.add_mutually_exclusive_group() diff --git a/cuda_malloc.py b/cuda_malloc.py index c1d9ae3cab5c..96a1bea00067 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -2,6 +2,7 @@ import importlib.util from comfy.cli_args import args import subprocess +import logging #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
def get_gpu_names(): @@ -50,7 +51,7 @@ def enum_display_devices(): "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60" } -def cuda_malloc_supported(): +def device_cuda_malloc_supported(): try: names = get_gpu_names() except: @@ -62,8 +63,8 @@ def cuda_malloc_supported(): return False return True - -if not args.cuda_malloc: +def cuda_malloc_supported(): + software_supported = False try: version = "" torch_spec = importlib.util.find_spec("torch") @@ -76,16 +77,20 @@ def cuda_malloc_supported(): version = module.__version__ if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch - args.cuda_malloc = cuda_malloc_supported() + software_supported = True except: pass + return (software_supported and device_cuda_malloc_supported()) -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" else: env_var += ",backend:cudaMallocAsync" + if not cuda_malloc_supported(): + logging.warning("WARNING: this card most likely does not support cuda-malloc\n") + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var