
Commit fced2ba

refine nvfp code, typofix (#777)
1 parent ca3f733 commit fced2ba

8 files changed (+26 / -18 lines)


auto_round/autoround.py

Lines changed: 7 additions & 7 deletions
@@ -397,9 +397,7 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
-
         self._check_configs()
-
         torch.set_printoptions(precision=3, sci_mode=True)

         if is_optimum_habana_available():
@@ -2032,7 +2030,6 @@ def calib(self, nsamples, bs):
                 if isinstance(data_new, torch.Tensor):
                     self.model(data_new)
                 elif isinstance(data_new, tuple) or isinstance(data_new, list):
-
                     self.model(*data_new)
                 else:
                     self.model(**data_new)
@@ -2499,7 +2496,7 @@ def get_act_max_hook(module, input, output):
                 module.act_max = act_max
             else:
                 act_max = act_max.to(module.act_max.device)
-                if is_nv_fp(self.data_type): ## for nvfp per-tensor input_global_scale calculation usage
+                if is_nv_fp(self.act_data_type): ## for nvfp per-tensor input_global_scale calculation usage
                     module.act_max = torch.max(
                         torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device)
                     )
@@ -2736,10 +2733,13 @@ def _quantize_block(
         with torch.no_grad():
             unwrapper_block(block, best_params)

+        if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
+            from auto_round.utils import set_amax_for_all_moe_layers
+
+            # enable moe experts act_max automatic generation for WrapperWALayer
+            set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
+
         if self.enable_quanted_input:
-            if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
-                # Enable moe experts act_max automatic generation for WrapperWALayer
-                set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
             if self.low_cpu_mem_usage:
                 block = block.to(device)
                 clear_memory()
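
For context on the data_type -> act_data_type fix above: the hook records a per-tensor running maximum of the activations so that an nvfp input_global_scale can later be derived per layer, so it must key off the activation data type rather than the weight data type. A minimal standalone sketch of that bookkeeping, mirroring the hook logic in the diff (the function name update_act_max is hypothetical):

import torch

def update_act_max(module: torch.nn.Module, act_max: torch.Tensor) -> None:
    # keep a scalar running max of activation magnitude on the module
    if not hasattr(module, "act_max"):
        module.act_max = act_max
    else:
        act_max = act_max.to(module.act_max.device)
        module.act_max = torch.max(
            torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device)
        )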

auto_round/data_type/nvfp.py

Lines changed: 6 additions & 3 deletions
@@ -88,7 +88,7 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs):
     if global_scale is None:
         tensor_max = tensor.abs().max().to(torch.float32)
         global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
-    global_scale = global_scale.to(tensor.device)
+    global_scale = global_scale.to(device=tensor.device, dtype=torch.float32)
     qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
     qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
     return qdq_res.to(orig_dtype), scale, None
@@ -102,11 +102,14 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, *
     tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
     if tensor_max is None:
         tensor_max = tensor.abs().max().to(torch.float32)
-    elif tensor_max is not None:
+    else:
         if not isinstance(tensor_max, torch.Tensor):
             tensor_max = torch.tensor(tensor_max, device=tensor.device, dtype=torch.float32)
+        else:
+            tensor_max = tensor_max.to(device=tensor.device, dtype=torch.float32)
         if tensor_max.numel() != 1:
-            tensor_max = tensor.abs().max().to(torch.float32)
+            tensor_max = tensor_max.abs().max()
+
     global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
     global_scale = global_scale.to(tensor.device)
     qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
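
The rewritten branch above lets nv_fp4_with_static_gs accept tensor_max as a Python number, a tensor on another device, or a multi-element tensor, and always reduces it to a scalar float32 on the input's device before the global scale is computed. A standalone sketch of that normalization (the helper name _normalize_tensor_max is hypothetical):

import torch

def _normalize_tensor_max(tensor: torch.Tensor, tensor_max=None) -> torch.Tensor:
    # reduce an optional user-supplied max to a 0-dim float32 tensor on tensor.device
    if tensor_max is None:
        tensor_max = tensor.abs().max().to(torch.float32)
    else:
        if not isinstance(tensor_max, torch.Tensor):
            tensor_max = torch.tensor(tensor_max, device=tensor.device, dtype=torch.float32)
        else:
            tensor_max = tensor_max.to(device=tensor.device, dtype=torch.float32)
        if tensor_max.numel() != 1:
            tensor_max = tensor_max.abs().max()
    return tensor_max

x = torch.randn(8, 32)
assert _normalize_tensor_max(x, 3.5).numel() == 1
assert _normalize_tensor_max(x, x.abs().amax(dim=-1)).numel() == 1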

auto_round/export/export_to_autoround/export.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
     if is_nv_fp(data_type) or is_mx_fp(data_type): ## detect nvfp & mxfp first
         from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import save_quantized_as_fp

-        return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round", **kwargs)
+        return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs)

     if kwargs.get("data_type", "int") == "fp" and kwargs.get("bits", 16) == 8 and kwargs.get("act_bits", 16) >= 16:
         from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 1 addition & 2 deletions
@@ -47,6 +47,7 @@
 ]


+@torch.compile()
 def pack_layer(name, model, backend):
     if name == "lm_head": # TODO: Check vLLM inference status to determine whether to enable this feature
         return
@@ -157,8 +158,6 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
     quantization_config["block_name_to_quantize"] = quantization_config.pop("to_quant_block_names", None)
     quantization_config["quant_method"] = "auto-round"
     quantization_config["packing_format"] = backend
-    quantization_config["scale_format"] = ("e8m0",)
-    quantization_config["scale_calculation_mode"] = ("even",)

     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
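
On the @torch.compile() addition above: torch.compile called with no function argument returns a decorator, so the decorated pack_layer is compiled lazily on its first invocation. A minimal illustration of the same pattern (the function below is a hypothetical stand-in, not the real pack_layer):

import torch

@torch.compile()
def scale_and_round(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # compiled on first call; subsequent calls reuse the compiled graph
    return torch.round(t / scale)

out = scale_and_round(torch.randn(4, 16), torch.tensor(0.5))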

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 4 additions & 2 deletions
@@ -149,15 +149,17 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
         elif torch.xpu.is_available():
             device = "xpu:0"

-        W = linear.weight.data.to(device).clone() # TODO check is nesscessory
+        W = linear.weight.data.to(device).clone()
         if isinstance(linear, nn.Conv2d):
             W = W.flatten(1)
         if isinstance(linear, transformers.pytorch_utils.Conv1D):
             W = W.t()

-        tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(linear.weight, self.group_size)
+        tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size)
+        scales = scales.to(device)
         if self.is_nv:
             assert global_scale is not None and global_scale.numel() == 1
+            global_scale = global_scale.to(device)
             scaled_tensor = tensor.to(global_scale.dtype) * get_reciprocal(
                 scales.reshape(tensor.shape[0], -1) * get_reciprocal(global_scale)
             )
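
In the nvfp branch above, the padded weight is divided by the per-group scales combined with the per-tensor global_scale, i.e. scaled = W / (scales / global_scale) = W * global_scale / scales. A small numeric sketch under assumed shapes (the per-group scale derivation here is illustrative, not the library's exact recipe):

import torch

def get_reciprocal(t: torch.Tensor) -> torch.Tensor:
    # zero-safe reciprocal, assumed to mirror the helper referenced in pack()
    return torch.where(t == 0, torch.zeros_like(t), 1.0 / t)

group_size = 16
W = torch.randn(64, group_size)                   # weight already reshaped into 64 groups of 16
global_scale = torch.tensor(448.0 * 6.0) / W.abs().max()
scales = W.abs().amax(dim=1, keepdim=True) / 6.0  # illustrative per-group scale

scaled = W.to(global_scale.dtype) * get_reciprocal(
    scales.reshape(W.shape[0], -1) * get_reciprocal(global_scale)
)
print(scaled.shape)  # torch.Size([64, 16])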

auto_round/special_model_handler.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def _handle_special_model(model):


 def _handle_moe_model(model):
-    if model.config.model_type in CONVERT_EXPERT_TO_LINEAR_MODELS:
+    if hasattr(model.config, "model_type") and model.config.model_type in CONVERT_EXPERT_TO_LINEAR_MODELS:
         from tqdm import tqdm

         from auto_round.utils import clear_memory

auto_round/utils.py

Lines changed: 5 additions & 1 deletion
@@ -435,7 +435,11 @@ def _get_llm_block_names(model):
         return block_names

     def _get_vlm_block_names(model, quant_vision=False):
-        if hasattr(model, "config") and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys():
+        if (
+            hasattr(model, "config")
+            and hasattr(model.config, "model_type")
+            and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys()
+        ):
             return SPECIAL_MULTIMODAL_BLOCK.get(model.config.model_type)(model, quant_vision=quant_vision)
         block_names = []
         target_modules = []

auto_round/wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ def _qdq_act(self, x, act_max_scale, act_max=None):
             data_type=self.act_data_type,
             max_scale=act_max_scale,
             tensor_max=act_max,
-            global_scale=getattr(self, "weight_global_scale", None),
+            global_scale=getattr(self, "input_global_scale", None),
         )
         return x, scale, zp
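
The typofix above matters because the weight and the activations carry separate per-tensor global scales: weight_global_scale is derived from the weight's maximum, while input_global_scale comes from the calibrated act_max, and _qdq_act quantizes activations. A rough sketch of how the two scales could be derived on a wrapped layer, assuming FLOAT8_E4M3_MAX = 448.0 and FLOAT4_E2M1_MAX = 6.0 as in auto_round/data_type/nvfp.py (the helper name is hypothetical):

import torch

FLOAT8_E4M3_MAX = 448.0  # assumed value of the constant referenced in nvfp.py
FLOAT4_E2M1_MAX = 6.0    # assumed value of the constant referenced in nvfp.py

def set_global_scales(layer: torch.nn.Linear, act_max: torch.Tensor) -> None:
    # per-tensor scale for the weight path
    layer.weight_global_scale = (
        FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / layer.weight.abs().max().to(torch.float32)
    )
    # per-tensor scale for the activation path, read by _qdq_act as input_global_scale
    layer.input_global_scale = (
        FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / act_max.abs().max().to(torch.float32)
    )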
