
Conversation

Kaihui-intel (Contributor):

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "opensourcerelease/DeepSeek-R1-bf16"
# model_name = "/data1/DeepSeek-R1-bf16"  # or point at a local copy of the checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=False, torch_dtype="auto")

from auto_round import AutoRound

autoround = AutoRound(
    model=model,
    tokenizer=tokenizer,
    nsamples=512,
    batch_size=4,
    low_gpu_mem_usage=False,
    device_map="auto",
    seqlen=2048,
)
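
The snippet stops after construction; assuming the standard AutoRound quantize()/save_quantized() API, the usual follow-up would be (output directory is a placeholder):

autoround.quantize()
autoround.save_quantized("./DeepSeek-R1-bf16-int4", format="auto_round")  # placeholder output path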

# Calculate all block linear memory except for the second modulelist
total_linear_memory = 0
for n, m in model.named_modules():
    if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
Contributor:

Call get_block_names here instead of matching ModuleList by class name.

for n, m in model.named_modules():
    if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
        for name, module in m[-1].named_modules():
            if isinstance(module, torch.nn.Linear):
Contributor:

Conv1D is also supported, so please handle it here as well.

"""
total_memory = bytes_to_gigabytes(torch.cuda.get_device_properties(i).total_memory)
reserved_memory = bytes_to_gigabytes(torch.cuda.memory_reserved(i))
allocated_memory = bytes_to_gigabytes(torch.cuda.memory_allocated(i))
Contributor:

It would be better to support XPU too. For now, you could raise an exception stating that XPU does not support device_map="auto".
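
A minimal sketch of such an interim guard (the helper name and exact condition are assumptions, not the PR's code):

import torch

def _check_device_map_auto_support(device: str) -> None:
    # Interim guard: reject device_map="auto" when the target is XPU (or only XPU is available).
    if "xpu" in str(device) or (hasattr(torch, "xpu") and torch.xpu.is_available() and not torch.cuda.is_available()):
        raise NotImplementedError('device_map="auto" is not yet supported on XPU; please pass an explicit device map.')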

all_blocks = get_block_names(model)
m = get_module(model, all_blocks[0][-1])
for name, module in m.named_modules():
    if isinstance(module, (torch.nn.Linear, transformers.pytorch_utils.Conv1D)):
Contributor:

Use the supported-types constant (SUPPORTED_LAYER_TYPES) instead of listing the types inline.

param_size = (
    sum(p.numel() for p in module.parameters()) * module.weight.element_size()
)  # element_size() gives the per-parameter byte width, so no float32 assumption is needed
block_memory += param_size
block_memory = block_memory / 1024**3  # bytes -> GiB
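
Taken together, the per-block estimate under review amounts to roughly the following sketch (the helper name is made up for illustration):

import torch
from transformers.pytorch_utils import Conv1D

def _estimate_block_param_mem_gb(block) -> float:
    # Sum the parameter bytes of the quantizable layers in one block; element_size()
    # handles bf16 and fp32 checkpoints without hard-coding a byte width.
    total_bytes = 0
    for module in block.modules():
        if isinstance(module, (torch.nn.Linear, Conv1D)):
            total_bytes += sum(p.numel() * p.element_size() for p in module.parameters())
    return total_bytes / 1024**3
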
wenhuach21 (Contributor), Sep 2, 2025:

For VLMs, different blocks may need different amounts of memory. Why not move this code into the quant_blocks function?
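
A rough per-block variant of that idea, reusing the helper sketched above together with get_block_names/get_module from the diff (names as they appear there):

block_names = get_block_names(model)[0]
block_mems_gb = {name: _estimate_block_param_mem_gb(get_module(model, name)) for name in block_names}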

Signed-off-by: Kaihui-intel <[email protected]>
@@ -217,6 +216,7 @@ def __init__(
disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", False)
static_kv_dtype = kwargs.pop("static_kv_dtype", None)
self.vlm = kwargs.pop("vlm") if "vlm" in kwargs else False
self.mem_expansion_factor = kwargs.pop("mem_expansion_factor", None)
Contributor:

How about ram_per_param_scale? It would also be better to add a comment explaining what this variable means.
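
A sketch of what the rename plus explanatory comment could look like (wording and default are placeholders, not the PR's final code):

# ram_per_param_scale: multiplier applied to a layer's parameter memory to approximate the
# peak memory needed while tuning it (weights plus activations and temporary buffers).
self.ram_per_param_scale = kwargs.pop("ram_per_param_scale", None)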

"""Automatically sets the device map for the model based on available GPUs and memory constraints."""
num_gpus = torch.cuda.device_count() - 1
if num_gpus == 0:
def get_block_info(self, block, input_ids, supported_types=SUPPORTED_LAYER_TYPES) -> tuple[float, float]:
Contributor:

Could you suggest a more precise name, preferably one that includes ‘mem’?

tensors of the first block, assuming bfloat16 or float32 precision.
"""
# Calculate all block linear memory
total_linear_memory = 0
Contributor:

total_param_mem?

if self.low_gpu_mem_usage:
    return block_memory, 0

# assuming bfloat16 or float32, input and output
Contributor:

Start the comment with an upper-case letter.

device_memory[cuda_devices[device_idx]] -= layer_memory * mem_expansion_factor
if device_idx >= len(cuda_devices):
    raise ValueError(
        f"model is too large to fit in {num_gpus} GPUs, "
Contributor:

As discussed, for device 0 we use mem_expansion_factor; for the other devices we just split the remaining parameters. If a device's share exceeds layer_memory * mem_expansion_factor, log a warning rather than raising an exception.
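
A self-contained sketch of one possible reading of this strategy (the function name and signature are illustrative, not the PR's code):

def assign_layers_to_devices(layer_mems_gb, device_free_gb, mem_expansion_factor, logger):
    # Fill device 0 under the scaled estimate, split whatever is left across the other
    # devices round-robin, and only warn (never raise) when a layer's scaled estimate
    # exceeds a device's free memory.
    scaled = {name: mem * mem_expansion_factor for name, mem in layer_mems_gb.items()}
    device_map, used0, leftover = {}, 0.0, []
    for name in layer_mems_gb:
        if used0 + scaled[name] <= device_free_gb[0]:
            device_map[name] = 0
            used0 += scaled[name]
        else:
            leftover.append(name)
    others = max(len(device_free_gb) - 1, 1)
    for i, name in enumerate(leftover):
        dev = 1 + i % others if len(device_free_gb) > 1 else 0
        if scaled[name] > device_free_gb[dev]:
            logger.warning(f"Layer {name} may not fit on device {dev}; tuning may OOM.")
        device_map[name] = dev
    return device_map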

if self.device_map == "auto":
    self.set_auto_device_map_in_block(block, input_ids)

if self.device_map is not None:
    from accelerate import dispatch_model
Contributor:

Please remember to also support this in the CLI scenario: auto-round --model xxx --devices 0,1,2

Signed-off-by: Kaihui-intel <[email protected]>
@@ -506,39 +507,34 @@ def _set_device_for_matching_module(self, name: str, device: str) -> None:
else:
    module.tuning_device = device

-    def get_block_info(self, block, input_ids, supported_types=SUPPORTED_LAYER_TYPES) -> tuple[float, float]:
+    def get_block_mem(self, block, input_ids, supported_types=SUPPORTED_LAYER_TYPES) -> tuple[float, float]:
Contributor:

estimate_tuning_block_mem, predict_tuning_block_mem, or something along those lines.

logger.warning(
    f"Layer {layer_name} may not fit in available GPU memory. "
    "Consider lowering ram_per_param_scale, using more GPUs, "
    "or reducing model size."
)
Contributor:

Remove "reducing model size" from the message.

wenhuach21 (Contributor), Sep 4, 2025:

Consider using more GPUs or reducing mem_per_param_scale if OOM occurs.

Besides, you need to add a mem_per_param_scale argument in llm.py.
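
Combining the two wording suggestions, the warning might end up roughly as follows (a sketch, not the merged text):

logger.warning(
    f"Layer {layer_name} may not fit in available GPU memory. "
    "Consider using more GPUs or reducing mem_per_param_scale if OOM occurs."
)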

    device_map[layer_name] = device_idx
    device_memory[device_idx] -= layer_memory * ram_per_param_scale
else:
    logger.warning(
Contributor:

better to use warning_once?

if self.low_gpu_mem_usage:
    return block_memory, 0

-    # assuming bfloat16 or float32, input and output
+    # Assuming bfloat16 or float32, input and output
input_bytes = 2 if self.amp_dtype != torch.float32 else 4
wenhuach21 (Contributor), Sep 4, 2025:

input_ids[0] should already have a dtype; read it from there instead of assuming.
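
For example, something along these lines (assuming input_ids is the list of cached input tensors, as elsewhere in the file):

input_bytes = input_ids[0].element_size()  # byte width taken from the cached tensor itself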

Signed-off-by: Kaihui-intel <[email protected]>
@@ -2460,6 +2549,10 @@ def _quantize_block(
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
set_module(block, n, new_layer)

if self.device_map == "auto":
    self.set_auto_device_map_in_block(block, input_ids)
Contributor:

Should this be private, i.e. _set_auto_device_map_in_block?


def set_auto_device_map_in_block(self, block, input_ids, supported_types=SUPPORTED_LAYER_TYPES) -> None:
    """Automatically sets the device map for the block based on available GPUs and memory constraints."""
    num_gpus = torch.cuda.device_count()
Contributor:

Better to check whether the device is CUDA; if it is a device like XPU, we should log a warning and try to use device 0.
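
A sketch of the suggested fallback (illustrative only; assumes torch and the module logger are already in scope):

if torch.cuda.is_available():
    num_devices = torch.cuda.device_count()
else:
    logger.warning('device_map="auto" currently assumes CUDA; falling back to device 0.')
    num_devices = 1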

Signed-off-by: Kaihui-intel <[email protected]>
Signed-off-by: Kaihui-intel <[email protected]>