Commit e66652a

Authored by WeiweiZhang1, pre-commit-ci[bot], and wenhuach21
rename llmcompressor to llm_compressor for align with other formats (#780)
* rename llmcompressor to llm_compressor for align with other formats
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add log, refine doc
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add act args for export config
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix line too long
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* refine packing device, refine nvfp logging
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fixtypo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix comments
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix import & log
* fix doctypo
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Signed-off-by: Zhang, Weiwei1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wenhua Cheng <[email protected]>
1 parent 7e014ca · commit e66652a

File tree

15 files changed, +91 −91 lines

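In user-facing terms, the change is the export format string: what was previously passed as "llmcompressor" is now "llm_compressor", matching the naming of the other export formats. A minimal usage sketch (model name and constructor arguments are illustrative and may differ from your installed auto-round version):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from auto_round import AutoRound

    # Illustrative model; the llm_compressor exporter only accepts MXFP/NVFP/FP8
    # weight configurations (see the checks in autoround.py below).
    model_name = "facebook/opt-125m"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    autoround = AutoRound(model, tokenizer, data_type="nv_fp4", bits=4, group_size=16)
    autoround.quantize()
    # "llm_compressor" replaces the old "llmcompressor" format name; for NVFP/MXFP
    # data types the exporter expands it to "llm_compressor:<data_type>".
    autoround.save_quantized("./opt-125m-nvfp4", format="llm_compressor")
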
auto_round/autoround.py

Lines changed: 21 additions & 13 deletions
@@ -691,7 +691,7 @@ def _check_configs(self) -> None:
         if self.gradient_accumulate_steps <= 0:
             raise ValueError("`gradient_accumulate_steps` must be positive")

-        if self.act_bits <= 8:
+        if self.act_bits <= 8 and (not is_nv_fp(self.act_data_type) or "static_gs" not in self.act_data_type):
             logger.warning(
                 "activation quantization is an experimental feature with limited support and a complex API. "
                 "And please save the quantized model to fake format as real deployment is not supported currently"
@@ -843,19 +843,21 @@ def _parse_format_to_list(self, format: str) -> list:
                     "for the current quantization configuration, "
                     "please change to `fake` format for research purpose"
                 )
-
                 formats[index] = format
-            elif format == "llmcompressor":
+            elif format == "llm_compressor":
                 from auto_round.export.export_to_llmcompressor import check_compressed_tensors_supported

                 if check_compressed_tensors_supported() and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
-                    format = format.replace("llmcompressor", f"llmcompressor:{self.data_type}")
+                    format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
                     formats[index] = format
                 elif not is_wfp8afp8(self):
                     logger.error(
-                        "Currently, the llmcompressor format only supports MXFP/NVFP/FP8. "
+                        "Currently, the llm_compressor format only supports MXFP/NVFP/FP8. "
                         "Please change format to fake or auto_round etc."
                     )
+            else:
+                if (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)) and format != "fake":
+                    logger.warning(f"nv_fp and mx_fp dtypes are not supported for export format: {format}")

         # Remove duplicates from formats list
         def remove_duplicates(lst):
@@ -887,8 +889,13 @@ def _check_supported_format(self, format: str) -> bool:
         # Only support to export afp8/nv_fp
         if self.act_bits <= 8:
             if not is_standard_fp(self.act_data_type) or self.act_dynamic:
-                if format == "llmcompressor":
-                    if is_nv_fp(self.act_data_type):
+                if "llm_compressor" in format:
+                    if is_nv_fp(self.act_data_type) and "static_gs" in self.act_data_type:
+                        logger.warning(
+                            f"AutoRound supports exporting to format '{format}', "
+                            "but loading quantized models in this format is not yet supported. "
+                            "It is currently recommended to export to the 'llm_compressor' format."
+                        )
                         return format
                 bits, group_size, sym, act_bits = 8, -1, True, 8
                 assert (
@@ -899,10 +906,11 @@ def _check_supported_format(self, format: str) -> bool:
                     and self.act_dynamic
                 ), (
                     f"Currently only support to export llmcompressor format for dynamic quantized"
-                    f" W{self.bits}A{self.act_bits} model, but got bits={self.bits},"
-                    f" group_size={self.group_size}, sym={self.sym}, act_bits={self.act_bits}"
+                    f" W{bits}Afp{act_bits} model, but got bits={self.bits}, data_type={self.data_type}"
+                    f" group_size={self.group_size}, sym={self.sym}"
+                    f", act_bits={self.act_bits}, act_data_type={self.act_data_type}"
                 )
-            elif format != "fake" and not is_nv_fp(format):
+            elif format != "fake" and (not is_nv_fp(format) or "static_gs" not in self.act_data_type):
                 logger.warning(
                     "Currently only support to export auto_round format quantized model"
                     " with fp8 or nv_fp4 dtype activation for activation quantization."
@@ -1652,7 +1660,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
                 or "gptq" in formats[0]
                 or "auto_round" in formats[0]
                 or "gguf" in formats[0]
-                or "llmcompressor" in formats[0]
+                or "llm_compressor" in formats[0]
             )
             and self.inplace
         ):
@@ -3017,8 +3025,8 @@ def save_quantized(
                 "Support for exporting activation quantization is limited. "
                 "Please ensure that your configuration is supported."
             )
-        if format == "llmcompressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
-            format = format.replace("llmcompressor", f"llmcompressor:{self.data_type}")
+        if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
+            format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")

         from auto_round.export import EXPORT_FORMAT

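The new llm_compressor branch above expands the bare format name into a data-type-qualified one when the weights use an NVFP or MXFP data type. A standalone illustration of just that string expansion (the data_type value is an example):

    data_type = "nv_fp4"  # example weight data type
    fmt = "llm_compressor"
    fmt = fmt.replace("llm_compressor", f"llm_compressor:{data_type}")
    print(fmt)  # -> "llm_compressor:nv_fp4"
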
auto_round/export/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -78,14 +78,14 @@ def _packing_layer_with_autoawq(*args, **kwargs):
     return pack_layer(*args, **kwargs)


-@register_format("llmcompressor")
+@register_format("llm_compressor")
 def _save_quantized_as_llmcompressor(*args, **kwargs):
     from auto_round.export.export_to_llmcompressor.export import save_quantized_as_llmcompressor

     return save_quantized_as_llmcompressor(*args, **kwargs)


-@register_layer_packing("llmcompressor")
+@register_layer_packing("llm_compressor")
 def _packing_layer_with_llmcompressor(*args, **kwargs):
     from auto_round.export.export_to_llmcompressor.export import pack_layer

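Both decorators register the new "llm_compressor" key so that lookups by format name resolve to the correct exporter and packing function. The real register_format / register_layer_packing implementations are not part of this diff; the sketch below only shows the general shape of such a name-keyed registry and is hypothetical:

    # Hypothetical sketch; the actual decorators in auto_round.export may differ.
    _EXPORTERS: dict = {}

    def register_format(name):
        def decorator(func):
            _EXPORTERS[name] = func  # map the format name to its handler
            return func
        return decorator

    @register_format("llm_compressor")
    def _save_quantized_as_llmcompressor(*args, **kwargs):
        ...
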
auto_round/export/export_to_autoround/export.py

Lines changed: 2 additions & 8 deletions
@@ -25,6 +25,7 @@
 import transformers
 from tqdm import tqdm

+from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
 from auto_round.utils import (
     SUPPORTED_FORMATS,
     SUPPORTED_LAYER_TYPES,
@@ -40,8 +41,6 @@
     set_module,
 )

-from .utils import check_neq_config
-

 def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits=16):
     """
@@ -313,12 +312,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
                 block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
             ):
                 neq_keys = check_neq_config(
-                    layer_config[layer_name],
-                    data_type=quantization_config["data_type"],
-                    bits=quantization_config["bits"],
-                    act_bits=quantization_config["act_bits"],
-                    group_size=quantization_config["group_size"],
-                    sym=quantization_config["sym"],
+                    layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
                 )
                 if len(neq_keys) > 0:
                     extra_config[layer_name] = {}

auto_round/export/export_to_autoround/export_to_fp8.py

Lines changed: 5 additions & 18 deletions
@@ -23,8 +23,10 @@
 from tqdm import tqdm

 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
+from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
+    _get_device,
     check_start_with_block_name,
     check_to_quantized,
     filter_quantization_config,
@@ -33,8 +35,6 @@
     set_module,
 )

-from .utils import check_neq_config
-

 class FP8WOQLinear(torch.nn.Module):

@@ -86,11 +86,7 @@ def pack_layer(layer_name, model, data_type, packing_device=None):
         None: The function modifies the model in place.
     """
     if packing_device is None:
-        packing_device = "cpu"
-        if torch.cuda.is_available():
-            packing_device = "cuda"
-        elif torch.xpu.is_available():
-            packing_device = "xpu"
+        packing_device = _get_device()
     layer = get_module(model, layer_name)
     if hasattr(layer, "orig_layer"):
         layer = layer.orig_layer
@@ -187,12 +183,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round",
                 block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
             ):
                 neq_keys = check_neq_config(
-                    layer_config[layer_name],
-                    data_type=quantization_config["data_type"],
-                    bits=quantization_config["bits"],
-                    act_bits=quantization_config["act_bits"],
-                    group_size=quantization_config["group_size"],
-                    sym=quantization_config["sym"],
+                    layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
                 )
                 if len(neq_keys) > 0:
                     extra_config[layer_name] = {}
@@ -205,11 +196,7 @@
     max_workers = 1
     if not torch.cuda.is_available() and not torch.xpu.is_available():
         max_workers = 2  ## 2 with cuda packing will cause hang occasionally
-    packing_device = "cpu"
-    if torch.cuda.is_available():
-        packing_device = "cuda"
-    elif torch.xpu.is_available():
-        packing_device = "xpu"
+    packing_device = _get_device()
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         with tqdm(total=len(names), leave=True) as pbar:

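The repeated cpu/cuda/xpu selection removed above is replaced by the shared _get_device helper imported from auto_round.utils. Its implementation is not shown in this diff; based on the inline logic it replaces, it presumably behaves roughly like this sketch:

    import torch

    def _get_device() -> str:
        # Reconstruction of the removed inline logic; the real helper in
        # auto_round.utils may differ (e.g. return "cuda:0" or a torch.device).
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"
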
auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 2 additions & 7 deletions
@@ -24,6 +24,7 @@
 import transformers
 from tqdm import tqdm

+from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
@@ -39,7 +40,6 @@
 from auto_round.wrapper import WrapperWALayer

 from .qlinear_fp import QuantLinear
-from .utils import check_neq_config

 __all__ = [
     "pack_layer",
@@ -203,12 +203,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
                 block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
             ):
                 neq_keys = check_neq_config(
-                    layer_config[layer_name],
-                    data_type=quantization_config["data_type"],
-                    bits=quantization_config["bits"],
-                    act_bits=quantization_config["act_bits"],
-                    group_size=quantization_config["group_size"],
-                    sym=quantization_config["sym"],
+                    layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
                 )
                 if len(neq_keys) > 0:
                     extra_config[layer_name] = {}

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 8 additions & 11 deletions
@@ -38,7 +38,7 @@
 from auto_round.data_type.mxfp import FP32_EXPONENT_BIAS, FP32_MIN_NORMAL
 from auto_round.data_type.nvfp import cast_to_fp4, get_reciprocal
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
-from auto_round.utils import is_mx_fp, is_nv_fp
+from auto_round.utils import _get_device, is_mx_fp, is_nv_fp

 # from auto_round.utils import get_weight_compress_dtype
 logger = getLogger(__name__)
@@ -141,15 +141,11 @@ def post_init(self):
         pass

     def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_global_scale=None):
-        if linear.bias is not None:
-            self.bias = linear.bias.clone().half()
-        device = "cpu"
-        if torch.cuda.is_available():
-            device = "cuda:0"
-        elif torch.xpu.is_available():
-            device = "xpu:0"
-
-        W = linear.weight.data.to(device).clone()
+        device = _get_device()
+        if getattr(linear, "bias", None) is not None:
+            self.bias = linear.bias.detach().to(torch.float16)
+
+        W = linear.weight.data.detach().to(device)
         if isinstance(linear, nn.Conv2d):
             W = W.flatten(1)
         if isinstance(linear, transformers.pytorch_utils.Conv1D):
@@ -163,7 +159,8 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
             scaled_tensor = tensor.to(global_scale.dtype) * get_reciprocal(
                 scales.reshape(tensor.shape[0], -1) * get_reciprocal(global_scale)
             )
-            scaled_tensor = cast_to_fp4(torch.clamp(scaled_tensor, -6.0, 6.0))
+            scaled_tensor.clamp_(-6.0, 6.0)
+            scaled_tensor = cast_to_fp4(scaled_tensor)
         else:
             scaled_tensor = tensor / (2 ** scales.reshape(tensor.shape[0], -1))
         scaled_tensor = revert_tensor_by_pad(scaled_tensor, orig_shape=orig_shape, pad_len=pad_len)

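The reworked pack() scales each weight group by its per-group scale folded with the reciprocal of the NVFP4 global scale, clamps in place to the FP4 (E2M1) representable range of ±6, and only then casts. A rough standalone illustration of those steps with plain torch ops (get_reciprocal and cast_to_fp4 are auto_round helpers and are only approximated here):

    import torch

    def safe_reciprocal(x, eps=1e-12):
        # stand-in for auto_round's get_reciprocal: guarded 1/x
        return torch.where(x.abs() > eps, 1.0 / x, torch.zeros_like(x))

    weight = torch.randn(4, 16)          # toy weight, one row per output channel
    scales = torch.rand(4, 1) + 0.5      # toy per-group scales
    global_scale = torch.tensor(2.0)     # toy NVFP4 global scale

    scaled = weight * safe_reciprocal(scales * safe_reciprocal(global_scale))
    scaled.clamp_(-6.0, 6.0)             # FP4 E2M1 max magnitude is 6
    # cast_to_fp4(scaled) would follow in the real pack() path
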
auto_round/export/export_to_autoround/qlinear_triton_act.py

Lines changed: 3 additions & 3 deletions
@@ -41,6 +41,8 @@
 import torch.nn as nn
 import transformers

+from auto_round.utils import _get_device
+
 logger = getLogger(__name__)


@@ -117,16 +119,14 @@ def post_init(self):
         pass

     def pack(self, linear, scales, zeros, act_scales, w_bf16_to_fp8_scale, g_idx=None):
+        device = _get_device()
         scales_t = scales.t().contiguous()

         self.act_scales.data.copy_(act_scales.squeeze().clone())
         self.w_bf16_to_fp8_scale.data.copy_(w_bf16_to_fp8_scale.squeeze().clone())
         if linear.bias is not None:
             self.bias = linear.bias.clone().half()
         self.scales = scales_t.clone().half()
-        device = "cpu"
-        if torch.cuda.is_available():
-            device = "cuda:0"

         W = linear.weight.data.to(device).clone()
         if isinstance(linear, nn.Conv2d):

auto_round/export/export_to_autoround/utils.py

Lines changed: 28 additions & 18 deletions
@@ -12,26 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+REQUIRED_CONFIG_KEYS = (
+    "data_type",
+    "bits",
+    "group_size",
+    "sym",
+    "act_bits",
+    "act_data_type",
+    "act_group_size",
+    "act_sym",
+    "act_dynamic",
+)

-def check_neq_config(config, data_type, bits, act_bits, group_size, sym):
-    """
-    Checks if the provided configuration parameters are not equal to the values in the config dictionary.

-    Args:
-        config (dict): A dictionary containing the configuration parameters.
-        data_type (str): The expected data type.
-        bits (int): The expected number of bits.
-        group_size (int): The expected group size.
-        sym (bool): The expected symmetry flag.
+def check_neq_config(config: dict, **expected) -> dict[str, tuple]:
+    """
+    Compare a config dict against expected values.
+    Ensures all required keys are present in both config and expected.

     Returns:
-        list: A list of strings indicating which configuration parameters do not match.
+        dict[str, tuple]: {key: (actual, expected)} for mismatched values.
     """
-    expected_config = {
-        "data_type": data_type,
-        "bits": bits,
-        "group_size": group_size,
-        "sym": sym,
-        "act_bits": act_bits,
-    }
-    return [key for key, expected_value in expected_config.items() if config.get(key) != expected_value]
+    # 1. Check missing from expected
+    missing_expected = [k for k in REQUIRED_CONFIG_KEYS if k not in expected]
+    if missing_expected:
+        raise ValueError(f"Missing expected values for keys: {missing_expected}")
+
+    # 2. Check missing from layer config
+    missing_config = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
+    if missing_config:
+        raise ValueError(f"Missing config values for keys: {missing_config}")
+
+    # 3. Collect mismatches
+    return {key: (config[key], expected[key]) for key in REQUIRED_CONFIG_KEYS if config[key] != expected[key]}

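With the new signature, callers pass the expected values as keyword arguments, typically by expanding the quantization config over REQUIRED_CONFIG_KEYS as the export modules above now do. A small usage sketch with made-up values:

    quantization_config = {
        "data_type": "int", "bits": 4, "group_size": 128, "sym": True,
        "act_bits": 16, "act_data_type": "int", "act_group_size": 128,
        "act_sym": True, "act_dynamic": True,
    }
    layer_cfg = dict(quantization_config, bits=8)  # this layer deviates only in bits

    neq_keys = check_neq_config(layer_cfg, **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS})
    print(neq_keys)  # -> {'bits': (8, 4)}
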
auto_round/export/export_to_llmcompressor/export.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ def pack_layer(layer_name, model, backend):

         return pack_layer(layer_name, model, backend)

-    ## passed as no other llmcompressor format is supported yet
-    logger.warning("No other llmcompressor packing format(except NVFP&MXFP) is supported yet, skip packing")
+    ## passed as no other llm_compressor format is supported yet
+    logger.warning("No other llm_compressor packing format(except NVFP&MXFP) is supported yet, skip packing")
     return


auto_round/script/llm.py

Lines changed: 1 addition & 1 deletion
@@ -511,7 +511,7 @@ def tune(args):
             "auto_round" not in format
             and "fake" not in format
             and "awq" not in format
-            and "llmcompressor" not in format
+            and "llm_compressor" not in format
         ):
             # TODO gptq could support some mixed precision config
             logger.warning(f"mixed precision exporting does not support {format} currently")

0 commit comments
