     convert_fp8_layer_to_linear,
     convert_fp8_model_to_16b_model,
     detect_device,
+    estimate_tuning_block_mem,
     find_matching_blocks,
     flatten_list,
     get_block_names,
+    get_device_memory,
     get_layer_config_by_gguf_format,
     get_layer_features,
     get_layer_names_in_block,
@@ -228,20 +230,19 @@ def __init__(
             logger.warning("`device` is deprecated, please use `device_map` instead")

         self.vlm = kwargs.pop("vlm") if "vlm" in kwargs else False
+        # Scale factor for RAM usage per parameter.
+        self.mem_per_param_scale = kwargs.pop("mem_per_param_scale", None)
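+        # None means "use the default": _set_auto_device_map_in_block falls back to 13 (1 when iters == 0, i.e. RTN).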

         if kwargs:
             logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")

-        if device_map is not None and "," in str(device_map):
-            raise ValueError(
-                "API does not support explicit set multiple devices," " please set CUDA_VISIBLE_DEVICES=0,1 yourself"
-            )
         if device_map is None:
             device_map = 0

         # Set device, must place after model loading
         if isinstance(device_map, (str, torch.device, int)):
             self.device = detect_device(device_map)
+
         elif isinstance(device_map, dict) and device_map:
             tmp_devices = []
             for val in device_map.values():
@@ -258,8 +259,12 @@ def __init__(

             self.device = tmp_devices[0]

-        if isinstance(device_map, dict) and device_map:
+        if (isinstance(device_map, dict) and device_map) or device_map == "auto":
             self.device_map = device_map
+        elif isinstance(device_map, str) and "," in device_map:
+            device_map = device_map.replace(" ", "")  # Remove any spaces
+            self.device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
+            self.device_map = "auto"
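+            # e.g. device_map="0,1" -> device_list=[0, 1]; per-layer placement is resolved later by the "auto" path.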
         else:
             self.device_map = None
         self._set_device_map_in_blocks(self.device_map)
@@ -543,6 +548,8 @@ def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None:
             self.device_map = None
         if not device_map:
             return
+        if self.device_map == "auto" and device_map == "auto":
+            return
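+            # Nothing to parse for "auto"; the concrete map is built per block by _set_auto_device_map_in_block.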
         if isinstance(device_map, str):
             device_map = device_map.replace(" ", "")
             infos = device_map.split(",")
@@ -583,6 +590,71 @@ def _set_device_for_matching_module(self, name: str, device: str) -> None:
         else:
             module.tuning_device = device

+    def _set_auto_device_map_in_block(self, block: torch.nn.Module, input_ids: list[torch.Tensor]) -> None:
+        """Automatically sets the device map for the block based on available GPUs and memory constraints."""
+        if torch.cuda.is_available():
+            num_gpus = torch.cuda.device_count()
+        elif torch.xpu.is_available():
+            logger.warning_once("XPU does not support auto device map yet, using device 0 for tuning.")
+            return
+        else:
+            raise RuntimeError("No CUDA or XPU devices found.")
+        if num_gpus <= 1:
+            self.device_map = None
+            return
+
+        if hasattr(self, "device_list") and self.device_list:
+            cuda_devices = [f"cuda:{i}" for i in self.device_list]
+            device_0 = cuda_devices[0]
+        else:
+            cuda_devices = [f"cuda:{i}" for i in range(num_gpus)]
+            device_0 = "cuda:0"
+
+        device_0_memory = get_device_memory(
+            self.device_list[0] if hasattr(self, "device_list") and self.device_list else 0
+        )
+        block_memory, input_output_memory = estimate_tuning_block_mem(block, input_ids)
+        if self.low_gpu_mem_usage:
+            input_output_memory = 0
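+        # With low_gpu_mem_usage the cached activations stay off the GPU, so they are not charged to device 0.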
+
+        mem_per_param_scale = 13 if self.mem_per_param_scale is None else self.mem_per_param_scale
+        if self.iters == 0:
+            mem_per_param_scale = 1  # for rtn
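+            # iters == 0 means plain RTN, which needs no tuning buffers beyond the weights themselves.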
+
+        if (block_memory * mem_per_param_scale + input_output_memory) < device_0_memory:
+            return  # fits on one GPU
+
+        device_map = {}
+        device_memory = {device: get_device_memory(int(device.split(":")[1])) for device in cuda_devices}
+        device_memory[device_0] = device_0_memory - input_output_memory
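+        # Device 0 also hosts the block's inputs and outputs, so its budget is reduced by that amount up front.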
+
+        device_idx = 0
+        # First fill device 0 to its maximum capacity, then spill the remaining layers onto the other devices
+        for n, m in block.named_modules():
+            if check_to_quantized(m):
+                layer_name = block.tmp_name + "." + n
+                layer_memory = m.weight.nbytes / 1024**3
+                if device_idx == 0 and layer_memory * mem_per_param_scale < device_memory[cuda_devices[device_idx]]:
+                    device_map[layer_name] = cuda_devices[device_idx]
+                    device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale
+                elif device_idx == 0:
+                    device_idx += 1  # Move to the next device once device 0 is full
+                    device_map[layer_name] = cuda_devices[device_idx]
+                    device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale
+                else:
+                    # Place each remaining layer on the device with the most free memory
+                    sorted_devices = sorted(cuda_devices, key=lambda d: device_memory[d], reverse=True)
+                    device_idx = sorted_devices[0]
+                    if layer_memory * mem_per_param_scale < device_memory[device_idx]:
+                        device_map[layer_name] = device_idx
+                        device_memory[device_idx] -= layer_memory * mem_per_param_scale
+                    else:
+                        logger.warning_once(
+                            f"Block {block.tmp_name} does not fit in the available GPU memory. "
+                            "Consider using more GPUs or reducing mem_per_param_scale if OOM occurs."
+                        )
+        self._set_device_map_in_blocks(device_map)
+
     def _dq_check(self) -> None:
         """Reset the default value of super_bits and super_group_size"""
         if self.data_type.endswith("_dq"):
@@ -1488,6 +1560,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             block = block.to(self.device)
         if _is_fp8_model(self.model):
             convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
+
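+        # Resolve "auto" into a concrete per-layer map now that the block and its cached inputs are known.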
+        if self.device_map == "auto":
+            self._set_auto_device_map_in_block(block, input_ids)
+
         # Dispatch model if needed
         if self.device_map is not None:
             from accelerate.hooks import AlignDevicesHook, add_hook_to_module
@@ -2551,6 +2627,9 @@ def _quantize_block(
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
                 set_module(block, n, new_layer)

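+        # Same per-block "auto" placement as in the RTN path above.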
+        if self.device_map == "auto":
+            self._set_auto_device_map_in_block(block, input_ids)
+
         if self.device_map is not None:
             for n, m in block.named_modules():
                 if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):