auto_round/autoround.py (7 additions, 7 deletions)

@@ -397,9 +397,7 @@ def __init__(
self.infer_bs_coeff = 1
self.enable_torch_compile = enable_torch_compile
self._adjust_torch_compile(enable_torch_compile)

self._check_configs()

- torch.set_printoptions(precision=3, sci_mode=True)

if is_optimum_habana_available():
@@ -2032,7 +2030,6 @@ def calib(self, nsamples, bs):
if isinstance(data_new, torch.Tensor):
self.model(data_new)
elif isinstance(data_new, tuple) or isinstance(data_new, list):

self.model(*data_new)
else:
self.model(**data_new)
@@ -2499,7 +2496,7 @@ def get_act_max_hook(module, input, output):
module.act_max = act_max
else:
act_max = act_max.to(module.act_max.device)
- if is_nv_fp(self.data_type): ## for nvfp per-tensor input_global_scale calculation usage
+ if is_nv_fp(self.act_data_type): ## for nvfp per-tensor input_global_scale calculation usage
module.act_max = torch.max(
torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device)
)
@@ -2736,10 +2733,13 @@ def _quantize_block(
with torch.no_grad():
unwrapper_block(block, best_params)

+ if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
+ from auto_round.utils import set_amax_for_all_moe_layers
+
+ # enable moe experts act_max automatic generation for WrapperWALayer
+ set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

if self.enable_quanted_input:
- if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
- # Enable moe experts act_max automatic generation for WrapperWALayer
- set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
if self.low_cpu_mem_usage:
block = block.to(device)
clear_memory()
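Two things change in autoround.py: the activation-max hook now branches on `self.act_data_type` (the activation data type) rather than `self.data_type` when deciding whether to keep a single per-tensor running max, and the MoE-expert `act_max` generation via `set_amax_for_all_moe_layers` now runs right after `unwrapper_block`, no longer only inside the `enable_quanted_input` branch. For orientation, here is a minimal self-contained sketch of the kind of per-tensor activation-max tracking such a hook performs; only the `act_max` attribute and the `torch.max` merge come from the diff, the rest is assumed:

```python
import torch
import torch.nn as nn


def act_max_hook(module, inputs, output):
    """Forward hook sketch: track the per-tensor abs-max of a layer's input activations.

    For nv_fp-style quantization the statistic is a single scalar (used later for the
    per-tensor input_global_scale), so each new batch is merged with the stored value
    via torch.max.
    """
    x = inputs[0] if isinstance(inputs, tuple) else inputs
    batch_max = x.detach().abs().max().to(torch.float32)
    if not hasattr(module, "act_max"):
        module.act_max = batch_max
    else:
        batch_max = batch_max.to(module.act_max.device)
        module.act_max = torch.max(
            torch.tensor([batch_max.max(), module.act_max.max()], device=batch_max.device)
        )


# usage sketch: attach during calibration, read layer.act_max afterwards
layer = nn.Linear(16, 16)
handle = layer.register_forward_hook(act_max_hook)
layer(torch.randn(4, 16))
handle.remove()
```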
auto_round/data_type/nvfp.py (6 additions, 3 deletions)

@@ -88,7 +88,7 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs):
if global_scale is None:
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
- global_scale = global_scale.to(tensor.device)
+ global_scale = global_scale.to(device=tensor.device, dtype=torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
@@ -102,11 +102,14 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, *
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
if tensor_max is None:
tensor_max = tensor.abs().max().to(torch.float32)
- elif tensor_max is not None:
+ else:
if not isinstance(tensor_max, torch.Tensor):
tensor_max = torch.tensor(tensor_max, device=tensor.device, dtype=torch.float32)
+ else:
+ tensor_max = tensor_max.to(device=tensor.device, dtype=torch.float32)
if tensor_max.numel() != 1:
- tensor_max = tensor.abs().max().to(torch.float32)
+ tensor_max = tensor_max.abs().max()

global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
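The nvfp.py changes make the global-scale path dtype- and device-safe: `global_scale` is cast to float32 on the tensor's device, and a caller-supplied `tensor_max` is normalized (non-tensors become float32 tensors, tensors are moved and cast, and multi-element statistics collapse to a scalar max). A standalone sketch of that per-tensor global-scale computation, assuming the usual format maxima FLOAT8_E4M3_MAX = 448.0 and FLOAT4_E2M1_MAX = 6.0 and a zero-guarded reciprocal in place of the repo's `get_reciprocal`:

```python
import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8 e4m3 value
FLOAT4_E2M1_MAX = 6.0    # largest finite float4 e2m1 value


def _reciprocal(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the repo's get_reciprocal: 1/x with a guard against division by zero
    return torch.where(x == 0, torch.zeros_like(x), 1.0 / x)


def nvfp4_global_scale(tensor: torch.Tensor, tensor_max=None) -> torch.Tensor:
    """Per-tensor global scale for NVFP4-style quantization (illustrative only)."""
    if tensor_max is None:
        tensor_max = tensor.abs().max().to(torch.float32)
    else:
        if not isinstance(tensor_max, torch.Tensor):
            tensor_max = torch.tensor(tensor_max, device=tensor.device, dtype=torch.float32)
        else:
            tensor_max = tensor_max.to(device=tensor.device, dtype=torch.float32)
        if tensor_max.numel() != 1:  # e.g. a per-channel statistic from calibration
            tensor_max = tensor_max.abs().max()
    scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * _reciprocal(tensor_max)
    return scale.to(device=tensor.device, dtype=torch.float32)


# usage sketch
gs = nvfp4_global_scale(torch.randn(128, 256))
```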
auto_round/export/export_to_autoround/export.py (1 addition, 1 deletion)

@@ -267,7 +267,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
if is_nv_fp(data_type) or is_mx_fp(data_type): ## detect nvfp & mxfp first
from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import save_quantized_as_fp

- return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round", **kwargs)
+ return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs)

if kwargs.get("data_type", "int") == "fp" and kwargs.get("bits", 16) == 8 and kwargs.get("act_bits", 16) >= 16:
from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround
auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py (1 addition, 2 deletions)

@@ -47,6 +47,7 @@
]


+ @torch.compile()
def pack_layer(name, model, backend):
if name == "lm_head": # TODO: Check vLLM inference status to determine whether to enable this feature
return
@@ -157,8 +158,6 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
quantization_config["block_name_to_quantize"] = quantization_config.pop("to_quant_block_names", None)
quantization_config["quant_method"] = "auto-round"
quantization_config["packing_format"] = backend
quantization_config["scale_format"] = ("e8m0",)
quantization_config["scale_calculation_mode"] = ("even",)

tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
auto_round/export/export_to_autoround/qlinear_fp.py (4 additions, 2 deletions)

@@ -149,15 +149,17 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
elif torch.xpu.is_available():
device = "xpu:0"

- W = linear.weight.data.to(device).clone() # TODO check is nesscessory
+ W = linear.weight.data.to(device).clone()
Contributor review comment on this line: why is this clone necessary?
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

- tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(linear.weight, self.group_size)
+ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size)
scales = scales.to(device)
if self.is_nv:
assert global_scale is not None and global_scale.numel() == 1
global_scale = global_scale.to(device)
scaled_tensor = tensor.to(global_scale.dtype) * get_reciprocal(
scales.reshape(tensor.shape[0], -1) * get_reciprocal(global_scale)
)
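The `pack` fix is that `reshape_pad_tensor_by_group_size` now operates on the preprocessed `W` (moved to the packing device, flattened for `nn.Conv2d`, transposed for HF `Conv1D`) instead of the raw `linear.weight`. The NVFP scaling that follows divides the weight by an effective per-group scale expressed relative to a single per-tensor `global_scale`. A simplified sketch of that two-level scaling, with the pad helper replaced by a plain reshape; the function name and toy inputs are assumptions:

```python
import torch


def scale_weight_for_pack(W: torch.Tensor, scales: torch.Tensor,
                          global_scale: torch.Tensor, group_size: int = 16) -> torch.Tensor:
    """Divide a weight by its effective NVFP scale: stored per-group scale / per-tensor global scale.

    Assumes W is already 2-D and its column count is a multiple of group_size
    (the repo pads via reshape_pad_tensor_by_group_size instead of asserting).
    """
    rows, cols = W.shape
    assert cols % group_size == 0 and global_scale.numel() == 1
    groups = cols // group_size
    Wg = W.reshape(rows, groups, group_size).to(torch.float32)
    eff_scale = scales.reshape(rows, groups, 1).to(torch.float32) / global_scale
    return (Wg / eff_scale).reshape(rows, cols)


# usage sketch with toy per-group scales
W = torch.randn(8, 32)
scales = W.reshape(8, 2, 16).abs().amax(dim=-1) / 6.0
scaled = scale_weight_for_pack(W, scales, torch.tensor(1.0))
```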
auto_round/special_model_handler.py (1 addition, 1 deletion)

@@ -112,7 +112,7 @@ def _handle_special_model(model):


def _handle_moe_model(model):
- if model.config.model_type in CONVERT_EXPERT_TO_LINEAR_MODELS:
+ if hasattr(model.config, "model_type") and model.config.model_type in CONVERT_EXPERT_TO_LINEAR_MODELS:
from tqdm import tqdm

from auto_round.utils import clear_memory
auto_round/utils.py (5 additions, 1 deletion)

@@ -435,7 +435,11 @@ def _get_llm_block_names(model):
return block_names

def _get_vlm_block_names(model, quant_vision=False):
if hasattr(model, "config") and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys():
if (
hasattr(model, "config")
and hasattr(model.config, "model_type")
and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys()
):
return SPECIAL_MULTIMODAL_BLOCK.get(model.config.model_type)(model, quant_vision=quant_vision)
block_names = []
target_modules = []
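The special_model_handler.py and utils.py edits are the same hardening: guard the `model_type` lookups so configs without that attribute (or models without a `config` at all) fall through instead of raising `AttributeError`. A tiny sketch of the pattern; the helper name is hypothetical:

```python
def get_model_type(model):
    """Return model.config.model_type if both attributes exist, else None (hypothetical helper)."""
    config = getattr(model, "config", None)
    return getattr(config, "model_type", None)


# usage sketch, mirroring the guarded checks in the diff
# if get_model_type(model) in CONVERT_EXPERT_TO_LINEAR_MODELS: ...
# if get_model_type(model) in SPECIAL_MULTIMODAL_BLOCK: ...
```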
auto_round/wrapper.py (1 addition, 1 deletion)

@@ -246,7 +246,7 @@ def _qdq_act(self, x, act_max_scale, act_max=None):
data_type=self.act_data_type,
max_scale=act_max_scale,
tensor_max=act_max,
global_scale=getattr(self, "weight_global_scale", None),
global_scale=getattr(self, "input_global_scale", None),
)
return x, scale, zp

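The wrapper.py fix passes the activation-side `input_global_scale` into the activation quant-dequant call instead of the weight's `weight_global_scale`. A hedged sketch of how that activation path could look, reusing the format maxima from the nvfp.py sketch; the derivation of `input_global_scale` from `act_max` is an assumption, only the attribute names come from the diff:

```python
import torch

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0


def input_global_scale_from_act_max(act_max: torch.Tensor) -> torch.Tensor:
    """Assumed derivation of a per-tensor activation global scale from a calibrated act_max."""
    amax = act_max.to(torch.float32).abs().max()
    return torch.tensor(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, dtype=torch.float32) / amax.clamp(min=1e-12)


def qdq_act_sketch(layer, x, quant_fn):
    """Illustrative activation quant-dequant dispatch for an nv_fp wrapper layer."""
    return quant_fn(
        x,
        tensor_max=getattr(layer, "act_max", None),               # calibrated per-tensor act max
        global_scale=getattr(layer, "input_global_scale", None),  # activation-side global scale
    )
```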