
Commit d7d2efa

Support auto device mapping (#781)
1 parent 4c597de commit d7d2efa

File tree

4 files changed: +159 -7 lines changed

    auto_round/autoround.py
    auto_round/script/llm.py
    auto_round/script/mllm.py
    auto_round/utils.py

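Usage note (not part of the diff): with this commit a caller can pass device_map="auto", or a comma-separated device list such as "0,1", straight to AutoRound instead of exporting CUDA_VISIBLE_DEVICES, along with an optional mem_per_param_scale kwarg. A minimal sketch, assuming the usual AutoRound(model, tokenizer, ...) constructor; the model name below is only an example:

    # Sketch only: exercises the new device_map handling from this commit.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from auto_round import AutoRound

    model_name = "facebook/opt-125m"  # example model, not part of the commit
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    ar = AutoRound(
        model,
        tokenizer,
        bits=4,
        group_size=128,
        device_map="auto",       # or "0,1"; multi-device strings previously raised a ValueError
        mem_per_param_scale=13,  # new kwarg: scale factor for per-parameter memory estimation
    )
    ar.quantize()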

auto_round/autoround.py

Lines changed: 84 additions & 5 deletions
@@ -59,9 +59,11 @@
     convert_fp8_layer_to_linear,
     convert_fp8_model_to_16b_model,
     detect_device,
+    estimate_tuning_block_mem,
     find_matching_blocks,
     flatten_list,
     get_block_names,
+    get_device_memory,
     get_layer_config_by_gguf_format,
     get_layer_features,
     get_layer_names_in_block,
@@ -228,20 +230,19 @@ def __init__(
             logger.warning("`device` is deprecated, please use `device_map` instead")

         self.vlm = kwargs.pop("vlm") if "vlm" in kwargs else False
+        # Scale factor for RAM usage per parameter.
+        self.mem_per_param_scale = kwargs.pop("mem_per_param_scale", None)

         if kwargs:
             logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")

-        if device_map is not None and "," in str(device_map):
-            raise ValueError(
-                "API does not support explicit set multiple devices," " please set CUDA_VISIBLE_DEVICES=0,1 yourself"
-            )
         if device_map is None:
             device_map = 0

         # Set device, must place after model loading
         if isinstance(device_map, (str, torch.device, int)):
             self.device = detect_device(device_map)
+
         elif isinstance(device_map, dict) and device_map:
             tmp_devices = []
             for val in device_map.values():
@@ -258,8 +259,12 @@ def __init__(

             self.device = tmp_devices[0]

-        if isinstance(device_map, dict) and device_map:
+        if (isinstance(device_map, dict) and device_map) or device_map == "auto":
             self.device_map = device_map
+        elif isinstance(device_map, str) and "," in device_map:
+            device_map = device_map.replace(" ", "")  # Remove any spaces
+            self.device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
+            self.device_map = "auto"
         else:
             self.device_map = None
         self._set_device_map_in_blocks(self.device_map)
@@ -543,6 +548,8 @@ def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None:
             self.device_map = None
         if not device_map:
             return
+        if self.device_map == "auto" and device_map == "auto":
+            return
         if isinstance(device_map, str):
             device_map = device_map.replace(" ", "")
             infos = device_map.split(",")
@@ -583,6 +590,71 @@ def _set_device_for_matching_module(self, name: str, device: str) -> None:
         else:
             module.tuning_device = device

+    def _set_auto_device_map_in_block(self, block: torch.nn.Module, input_ids: list[torch.Tensor]) -> None:
+        """Automatically sets the device map for the block based on available GPUs and memory constraints."""
+        if torch.cuda.is_available():
+            num_gpus = torch.cuda.device_count()
+        elif torch.xpu.is_available():
+            logger.warning_once("XPU does not support auto device map yet, using device 0 for tuning.")
+            return
+        else:
+            raise RuntimeError("No CUDA or XPU devices found.")
+        if num_gpus <= 1:
+            self.device_map = None
+            return
+
+        if hasattr(self, "device_list") and self.device_list:
+            cuda_devices = [f"cuda:{i}" for i in self.device_list]
+            device_0 = cuda_devices[0]
+        else:
+            cuda_devices = [f"cuda:{i}" for i in range(num_gpus)]
+            device_0 = "cuda:0"
+
+        device_0_memory = get_device_memory(
+            self.device_list[0] if hasattr(self, "device_list") and self.device_list else 0
+        )
+        block_memory, input_output_memory = estimate_tuning_block_mem(block, input_ids)
+        if self.low_gpu_mem_usage:
+            input_output_memory = 0
+
+        mem_per_param_scale = 13 if self.mem_per_param_scale is None else self.mem_per_param_scale
+        if self.iters == 0:
+            mem_per_param_scale = 1  # for RTN
+
+        if (block_memory * mem_per_param_scale + input_output_memory) < device_0_memory:
+            return  # the block fits on one GPU
+
+        device_map = {}
+        device_memory = {device: get_device_memory(int(device.split(":")[1])) for device in cuda_devices}
+        device_memory[device_0] = device_0_memory - input_output_memory
+
+        device_idx = 0
+        # First fill device 0 to its maximum capacity, then distribute the remaining layers across the other devices
+        for n, m in block.named_modules():
+            if check_to_quantized(m):
+                layer_name = block.tmp_name + "." + n
+                layer_memory = m.weight.nbytes / 1024**3
+                if device_idx == 0 and layer_memory * mem_per_param_scale < device_memory[cuda_devices[device_idx]]:
+                    device_map[layer_name] = cuda_devices[device_idx]
+                    device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale
+                elif device_idx == 0:
+                    device_idx += 1  # Move to the next device once device 0 is full
+                    device_map[layer_name] = cuda_devices[device_idx]
+                    device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale
+                else:
+                    # Assign the layer to the device with the most remaining memory
+                    sorted_devices = sorted(cuda_devices, key=lambda d: device_memory[d], reverse=True)
+                    device_idx = sorted_devices[0]
+                    if layer_memory * mem_per_param_scale < device_memory[device_idx]:
+                        device_map[layer_name] = device_idx
+                        device_memory[device_idx] -= layer_memory * mem_per_param_scale
+                    else:
+                        logger.warning_once(
+                            f"Block {block.tmp_name} does not fit in the available GPU memory. "
+                            "Consider using more GPUs or reducing mem_per_param_scale if OOM occurs."
+                        )
+        self._set_device_map_in_blocks(device_map)
+
     def _dq_check(self) -> None:
         """Reset the default value of super_bits and super_group_size"""
         if self.data_type.endswith("_dq"):
@@ -1488,6 +1560,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             block = block.to(self.device)
             if _is_fp8_model(self.model):
                 convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
+
+            if self.device_map == "auto":
+                self._set_auto_device_map_in_block(block, input_ids)
+
             # Dispatch model if needed
             if self.device_map is not None:
                 from accelerate.hooks import AlignDevicesHook, add_hook_to_module
@@ -2551,6 +2627,9 @@ def _quantize_block(
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
                 set_module(block, n, new_layer)

+        if self.device_map == "auto":
+            self._set_auto_device_map_in_block(block, input_ids)
+
         if self.device_map is not None:
             for n, m in block.named_modules():
                 if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
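The core of the change is the greedy placement in _set_auto_device_map_in_block: fill device 0 up to its estimated capacity (after reserving room for the cached block inputs/outputs), then hand each remaining quantizable layer to the device with the most free memory, scaling the raw weight size by mem_per_param_scale to approximate tuning overhead. A standalone sketch of that heuristic, with made-up layer sizes and device capacities (it simplifies the committed code, which also tracks a current device index):

    # Sketch of the greedy placement heuristic; all numbers are illustrative.
    def greedy_device_map(layer_sizes_gb, device_caps_gb, scale=13, io_gb=0.0):
        devices = list(device_caps_gb)          # e.g. ["cuda:0", "cuda:1"]
        free = dict(device_caps_gb)             # remaining estimated memory per device
        free[devices[0]] -= io_gb               # device 0 also holds the block inputs/outputs
        mapping, on_device_0 = {}, True
        for name, size in layer_sizes_gb.items():
            need = size * scale                 # tuning needs far more than the raw weights
            if on_device_0 and need < free[devices[0]]:
                target = devices[0]
            else:
                on_device_0 = False
                target = max(free, key=free.get)    # device with the most remaining memory
                if need >= free[target]:
                    print(f"warning: {name} may not fit; use more GPUs or a smaller scale")
            mapping[name] = target
            free[target] -= need
        return mapping

    print(greedy_device_map(
        {"block.0.q_proj": 0.5, "block.0.k_proj": 0.5, "block.0.mlp": 1.5},
        {"cuda:0": 24.0, "cuda:1": 24.0},
        io_gb=4.0,
    ))
    # -> q_proj and k_proj stay on cuda:0; the mlp spills over to cuda:1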

auto_round/script/llm.py

Lines changed: 8 additions & 1 deletion
@@ -104,6 +104,13 @@ def __init__(self, *args, **kwargs):
             help="minmax learning rate, if None, it will beset to be the same with lr",
         )

+        self.add_argument(
+            "--mem_per_param_scale",
+            default=13,
+            type=float,
+            help="Scale factor for memory per parameter, used to adjust memory usage estimation for tuning",
+        )
+
         self.add_argument("--seed", default=42, type=int, help="random seed")

         self.add_argument("--adam", action="store_true", help="whether to use adam optimizer instead of SignSGD")
@@ -436,7 +443,7 @@ def tune(args):
         raise RuntimeError("marlin backend only supports sym quantization, please remove --asym")

     # Must set this before import torch
-    set_cuda_visible_devices(args.device_map)
+    # set_cuda_visible_devices(args.device_map)
     device_str, use_auto_mapping = get_device_and_parallelism(args.device_map)

     import torch
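The scale factor is also exposed on the command line, and device selection now flows through the auto device mapping inside AutoRound instead of set_cuda_visible_devices (which is commented out rather than removed). A hedged example invocation, assuming the existing auto-round entry point and its --model/--bits/--group_size/--device_map flags; the model name is only an example:

    auto-round --model Qwen/Qwen2.5-7B-Instruct --bits 4 --group_size 128 --device_map 0,1 --mem_per_param_scale 13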

auto_round/script/mllm.py

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ def tune(args):
         raise ValueError(f"{format} is not supported, we only support {SUPPORTED_FORMATS}")

     # Must set this before import torch
-    set_cuda_visible_devices(args.device_map)
+    # set_cuda_visible_devices(args.device_map)
     device_str, use_auto_mapping = get_device_and_parallelism(args.device_map)

     import torch

auto_round/utils.py

Lines changed: 66 additions & 0 deletions
@@ -578,6 +578,10 @@ def is_valid_digit(s):
     if is_valid_digit(device):
         dev_idx = int(device)
         device = "auto"
+    if isinstance(device, str) and "," in device:  # device is "0,1,2"
+        device_list = [int(dev) for dev in device.split(",") if dev.isdigit()]
+        dev_idx = device_list[0] if device_list else None
+        device = "auto"
     if device is None or device == "auto":
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -1426,6 +1430,8 @@ def llm_load_model(
     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)

     model_cls = AutoModel if is_glm else AutoModelForCausalLM
+    if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code:
+        logger.warning("trust_remote_code is enabled by default, please ensure its correctness.")

     if low_cpu_mem_tmp_dir is None:
         low_cpu_mem_tmp_dir = "low_cpu_mem_tmp"
@@ -2563,6 +2569,66 @@ def is_static_wfp8afp8(ar):
     return False


+def bytes_to_gigabytes(bytes) -> float:
+    """
+    Converts bytes to gigabytes.
+
+    Args:
+        bytes (int): The number of bytes.
+
+    Returns:
+        float: The equivalent number of gigabytes.
+    """
+    return bytes / 1024 / 1024 / 1024
+
+
+def get_device_memory(i: int = 0) -> float:
+    """
+    Gets the total memory of the specified device.
+
+    Args:
+        i (int, optional): Device index. Defaults to 0.
+
+    Returns:
+        float: Total device memory in gigabytes.
+    """
+    if torch.cuda.is_available():
+        total_memory = bytes_to_gigabytes(torch.cuda.get_device_properties(i).total_memory)
+    elif torch.xpu.is_available():
+        raise RuntimeError("XPU does not support device_map='auto' currently.")
+    else:
+        raise RuntimeError("No supported device found (CUDA or XPU).")
+    return total_memory
+
+
+def estimate_tuning_block_mem(block: torch.nn.Module, input_ids: list[torch.Tensor]) -> tuple[float, float]:
+    """
+    Estimates the memory consumption of a specific block in the model.
+
+    Args:
+        block (torch.nn.Module): The block of the model to analyze.
+        input_ids (list[torch.Tensor]): A list of input tensors for the block.
+
+    Returns:
+        tuple: A tuple containing the following:
+            - block_memory (float): The memory consumption (in GB) of the block's quantizable layers.
+            - input_output_memory (float): The memory consumption (in GB) of the block's input and
+              output tensors.
+    """
+    # Sum the weight memory of all quantizable layers in the block
+    total_param_mem = 0
+    for name, module in block.named_modules():
+        if check_to_quantized(module):
+            param_size = module.weight.nbytes
+            total_param_mem += param_size
+    block_memory = total_param_mem / 1024**3  # Convert to GB
+
+    # Assuming bfloat16 or float32, count both the input and output activations
+    input_output_memory = 2 * sum(tensor.nbytes for tensor in input_ids) / 1024**3
+
+    return block_memory, input_output_memory
+
+
 def get_max_vram(ratio: float = 0.9) -> dict:
     max_memory = {}
     if torch.cuda.is_available():  # NVIDIA CUDA
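These helpers feed the "fits on one GPU" decision in _set_auto_device_map_in_block, which compares block_memory * mem_per_param_scale + input_output_memory against the memory of device 0. A worked example with illustrative numbers (the 24 GB capacity stands in for get_device_memory(0)):

    block_memory = 0.8          # GB of quantizable weights, as estimate_tuning_block_mem would report
    input_output_memory = 1.5   # GB of cached block inputs/outputs (treated as 0 when low_gpu_mem_usage is on)
    mem_per_param_scale = 13    # default scale; dropped to 1 when iters == 0 (plain RTN)

    need_gb = block_memory * mem_per_param_scale + input_output_memory   # 11.9 GB
    fits_on_one_gpu = need_gb < 24.0   # compare against get_device_memory(0) for, say, a 24 GB card
    print(need_gb, fits_on_one_gpu)    # 11.9 True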
