Support saving inc model for transformers-like api #2231

Merged (9 commits, Jul 15, 2025)
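
For orientation, a minimal usage sketch of the new for_inference flag (the model id and output directory are placeholders; the import path is assumed from the neural_compressor.transformers package modified below, and the flow mirrors the new test_save_load_for_inc_model test):

from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

# quantize but keep INC-format weights instead of replacing linear layers for inference
woq_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",                                   # placeholder model id
    quantization_config=RtnConfig(bits=4, group_size=16),
    for_inference=False,
)

# with for_inference=False, save_pretrained is bound to save_low_bit_for_inc, so an
# INC checkpoint plus all_checkpoint_keys.json is written to the output directory
woq_model.save_pretrained("./saved_results")               # placeholder output directory

# the saved INC model can be reloaded through the same API
loaded_model = AutoModelForCausalLM.from_pretrained("./saved_results")
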
@@ -26,6 +26,7 @@
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--quant_lm_head", action="store_true", help="whether to quant the lm_head layer in transformers")
parser.add_argument("--for_inference", action="store_true", help="whether to replace ipex linear for inference ")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--benchmark_iters", default=100, type=int, help="num iters for benchmark")
@@ -299,6 +300,7 @@
quantization_config=quantization_config,
trust_remote_code=args.trust_remote_code,
_commit_hash=args._commit_hash,
for_inference=args.for_inference,
)
elif args.load_in_4bit or args.load_in_8bit:
user_model = AutoModelForCausalLM.from_pretrained(
@@ -30,6 +30,7 @@
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--quant_lm_head", action="store_true", help="whether to quant the lm_head layer in transformers")
parser.add_argument("--for_inference", action="store_true", help="whether to replace ipex linear for inference")
parser.add_argument("--use_layer_wise", nargs='?', const=True, default=None, type=lambda x: bool(strtobool(x)),
help="""whether to use layerwise quant. Case-insensitive and
true values are 'y', 'yes', 't', 'true', 'on', and '1';
@@ -202,6 +203,7 @@
quantization_config=quantization_config,
trust_remote_code=args.trust_remote_code,
torch_dtype=torch.float16,
for_inference=args.for_inference,
)
elif args.load_in_4bit or args.load_in_8bit:
user_model = AutoModelForCausalLM.from_pretrained(args.model,
18 changes: 14 additions & 4 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -54,6 +54,7 @@
repack_awq_and_load_state_dict,
replace_linear,
save_low_bit,
save_low_bit_for_inc,
)
from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig

@@ -101,6 +102,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)

quantization_config = kwargs.pop("quantization_config", None)
for_inference = kwargs.pop("for_inference", True)
if not isinstance(config, PretrainedConfig):
config, _ = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
@@ -212,7 +214,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config.post_init_xpu()
if (device_map == "cpu" or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
model = model.float()
model = convert_to_quantized_model(model, quantization_config, device=device_map)
model = convert_to_quantized_model(
model, quantization_config, device=device_map, for_inference=for_inference
)
if isinstance(quantization_config, AwqConfig):
quantization_config.backend = "inc"
quantization_config.remove_redundant_parameters()
@@ -234,8 +238,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
device_map = torch.device(device_map) if isinstance(device_map, str) else device_map
model.hf_device_map = {"": device_map}
model.quantization_config = quantization_config

model.save_pretrained = types.MethodType(save_low_bit, model)
if for_inference:
model.save_pretrained = types.MethodType(save_low_bit, model)
else:
model.save_pretrained = types.MethodType(save_low_bit_for_inc, model)
logger.info("WeightOnlyQuant done.")
return model
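
The save path is chosen by rebinding save_pretrained on the model instance. A minimal, self-contained illustration of the types.MethodType pattern used here (the class and function names below are illustrative, not part of the PR):

import types

class Model:
    def save_pretrained(self, path):
        print(f"default save to {path}")

def save_low_bit_like(self, path):
    # stand-in for save_low_bit / save_low_bit_for_inc
    print(f"custom low-bit save of {type(self).__name__} to {path}")

m = Model()
m.save_pretrained = types.MethodType(save_low_bit_like, m)  # bind to this instance only
m.save_pretrained("./out")  # -> custom low-bit save of Model to ./out
del m.save_pretrained       # remove the instance attribute, as the save helpers do
m.save_pretrained("./out")  # -> default save to ./out (class method is visible again)
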

@@ -409,7 +415,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
user_agent=user_agent,
revision=revision,
commit_hash=commit_hash,
is_remote_code=True,
**(
{"is_remote_code": True}
if "is_remote_code" in _get_resolved_checkpoint_files.__code__.co_varnames
else {}
),
)
is_sharded = sharded_metadata is not None
resolved_archive_file = checkpoint_files if is_sharded else checkpoint_files[0]
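
The conditional keyword arguments above keep load_low_bit compatible with transformers versions whose _get_resolved_checkpoint_files signature does not accept is_remote_code. A small standalone sketch of this feature-detection pattern (the function names below are illustrative):

def call_with_optional_kwarg(func, *args, **kwargs):
    # pass is_remote_code only when the target function actually declares it,
    # so the call works whether or not the parameter exists in this version
    optional = (
        {"is_remote_code": True}
        if "is_remote_code" in func.__code__.co_varnames
        else {}
    )
    return func(*args, **kwargs, **optional)

def new_api(path, is_remote_code=False):
    return ("new", is_remote_code)

def old_api(path):
    return ("old",)

print(call_with_optional_kwarg(new_api, "ckpt"))  # ('new', True)
print(call_with_optional_kwarg(old_api, "ckpt"))  # ('old',)

inspect.signature would be a more robust test, but the co_varnames check mirrors what the diff itself uses.
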
69 changes: 63 additions & 6 deletions neural_compressor/transformers/quantization/utils.py
@@ -237,11 +237,6 @@ def _replace_linear(
model._modules[name].set_weights_bias(
module.qweight.data if hasattr(module, "qweight") else weight,
None if module.bias is None else module.bias.data,
**(
{"update_g_idx": not empty_weights}
if "update_g_idx" in model._modules[name].set_weights_bias.__code__.co_varnames
else {}
),
)
else:
raise Exception("{} device Unsupported weight only quantization!".format(device))
@@ -359,7 +354,7 @@ def run_fn_for_autoround(model, dataloader):
model(data)


def convert_to_quantized_model(model, config, device="cpu"):
def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
if device == "xpu" or device == torch.device("xpu"):
import intel_extension_for_pytorch

@@ -608,6 +603,14 @@ def set_nontext_module_config(model, to_quant_block_names, config):
logger.warning("The recommended ipex version is higher than 2.3.10 for xpu device.")

model.eval()
if not for_inference:
        # return the INC-format model without replacing linear layers for inference
if config.use_layer_wise and not (model.device == device or model.device.type == device):
            logger.warning(
                "Skip the device conversion to avoid running out of memory. "
                "It is recommended to use the saved quantized model for inference."
            )
return model
return model.to(device)

q_model = replace_linear(model, None, None, config, device=device)

@@ -688,6 +691,60 @@ def make_contiguous(model):
param.data = param.data.contiguous()


def save_low_bit_for_inc(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):

assert hasattr(self, "quantization_config"), "Detected this model is not a low-bit model."

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
del self.save_pretrained
make_contiguous(self)

self.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
self.save_pretrained = types.MethodType(save_low_bit_for_inc, self)
    # We conveniently save all the keys of the model to have them on hand,
    # so that when using 'low_cpumem load',
    # it's not necessary to load the entire model to extract its keys
    # and to avoid potential issues with garbage collection not being triggered.
all_checkpoint_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
json_file_path = os.path.join(save_directory, "all_checkpoint_keys.json")
with open(json_file_path, "w") as json_file:
json.dump(all_checkpoint_keys, json_file)
if push_to_hub:
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
            logger.warning(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers."
            )

        token = kwargs.pop("token", None)
        if token is not None and use_auth_token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = token if token is not None else use_auth_token

if token is not None:
kwargs["token"] = token
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("token"),
)
self.quantization_config.save_pretrained(save_directory, **kwargs)
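
As the comment in save_low_bit_for_inc explains, the key list is persisted so that a low-memory load can discover the checkpoint keys without materializing the weights. A small sketch of consuming that file (the directory name is a placeholder):

import json
import os

save_directory = "./saved_results"  # placeholder
with open(os.path.join(save_directory, "all_checkpoint_keys.json")) as f:
    all_keys = json.load(f)["all_checkpoint_keys"]
print(f"{len(all_keys)} state_dict keys recorded, e.g. {all_keys[:3]}")
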


def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):

assert hasattr(self, "quantization_config"), "Detected this model is not a low-bit model."
31 changes: 31 additions & 0 deletions test/3x/torch/quantization/weight_only/test_transformers.py
@@ -277,3 +277,34 @@ def test_vlm(self):
# model_name = "microsoft/Phi-3-vision-128k-instruct"
# woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True, attn_implementation='eager')
        # assert isinstance(woq_model.model.layers[0].self_attn.o_proj, WeightOnlyQuantizedLinear), "quantization failed."

def test_save_load_for_inc_model(self):
model_name_or_path = self.model_name_or_path

fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
dummy_input = fp32_model.dummy_inputs["input_ids"]

# RTN
woq_config = RtnConfig(bits=4, group_size=16)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
)
woq_output = woq_model(dummy_input)[0]

        # RTN with for_inference=False to keep the INC-format model for saving
woq_config = RtnConfig(bits=4, group_size=16)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
for_inference=False,
)

# save
output_dir = "./transformers_tmp"
woq_model.save_pretrained(output_dir)

# load
loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
loaded_output = loaded_model(dummy_input)[0]
        assert torch.equal(woq_output, loaded_output), "loaded output should be the same. Please double check."