diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_cpu_woq.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_cpu_woq.py
index 49166a59e34..033fe054919 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_cpu_woq.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_cpu_woq.py
@@ -26,6 +26,7 @@
 )
 parser.add_argument("--output_dir", nargs="?", default="./saved_results")
 parser.add_argument("--quant_lm_head", action="store_true", help="whether to quant the lm_head layer in transformers")
+parser.add_argument("--for_inference", action="store_true", help="whether to replace linear modules with IPEX linear for inference")
 # ============Benchmark configs==============
 parser.add_argument("--benchmark", action="store_true")
 parser.add_argument("--benchmark_iters", default=100, type=int, help="num iters for benchmark")
@@ -299,6 +300,7 @@
         quantization_config=quantization_config,
         trust_remote_code=args.trust_remote_code,
         _commit_hash=args._commit_hash,
+        for_inference=args.for_inference,
     )
 elif args.load_in_4bit or args.load_in_8bit:
     user_model = AutoModelForCausalLM.from_pretrained(
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
index f277b524937..31a00b46c7c 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
@@ -30,6 +30,7 @@
 )
 parser.add_argument("--output_dir", nargs="?", default="./saved_results")
 parser.add_argument("--quant_lm_head", action="store_true", help="whether to quant the lm_head layer in transformers")
+parser.add_argument("--for_inference", action="store_true", help="whether to replace linear modules with IPEX linear for inference")
 parser.add_argument("--use_layer_wise", nargs='?', const=True, default=None, type=lambda x: bool(strtobool(x)), help="""whether to use layerwise quant.
 Case-insensitive and true values are 'y', 'yes', 't', 'true', 'on', and '1';
@@ -202,6 +203,7 @@
         quantization_config=quantization_config,
         trust_remote_code=args.trust_remote_code,
         torch_dtype=torch.float16,
+        for_inference=args.for_inference,
     )
 elif args.load_in_4bit or args.load_in_8bit:
     user_model = AutoModelForCausalLM.from_pretrained(args.model,
diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py
index 9a8e5d62a72..dd17dfe580a 100644
--- a/neural_compressor/transformers/models/modeling_auto.py
+++ b/neural_compressor/transformers/models/modeling_auto.py
@@ -54,6 +54,7 @@
     repack_awq_and_load_state_dict,
     replace_linear,
     save_low_bit,
+    save_low_bit_for_inc,
 )
 
 from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig
@@ -101,6 +102,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         config = kwargs.pop("config", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        for_inference = kwargs.pop("for_inference", True)
         if not isinstance(config, PretrainedConfig):
             config, _ = AutoConfig.from_pretrained(
                 pretrained_model_name_or_path,
@@ -212,7 +214,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             quantization_config.post_init_xpu()
         if (device_map == "cpu" or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
             model = model.float()
-        model = convert_to_quantized_model(model, quantization_config, device=device_map)
+        model = convert_to_quantized_model(
+            model, quantization_config, device=device_map, for_inference=for_inference
+        )
         if isinstance(quantization_config, AwqConfig):
             quantization_config.backend = "inc"
             quantization_config.remove_redundant_parameters()
@@ -234,8 +238,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             device_map = torch.device(device_map) if isinstance(device_map, str) else device_map
             model.hf_device_map = {"": device_map}
         model.quantization_config = quantization_config
-
-        model.save_pretrained = types.MethodType(save_low_bit, model)
+        if for_inference:
+            model.save_pretrained = types.MethodType(save_low_bit, model)
+        else:
+            model.save_pretrained = types.MethodType(save_low_bit_for_inc, model)
         logger.info("WeightOnlyQuant done.")
         return model
 
@@ -409,7 +415,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 user_agent=user_agent,
                 revision=revision,
                 commit_hash=commit_hash,
-                is_remote_code=True,
+                **(
+                    {"is_remote_code": True}
+                    if "is_remote_code" in _get_resolved_checkpoint_files.__code__.co_varnames
+                    else {}
+                ),
             )
             is_sharded = sharded_metadata is not None
             resolved_archive_file = checkpoint_files if is_sharded else checkpoint_files[0]
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index a2a487469e9..d0675dd4579 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -237,11 +237,6 @@ def _replace_linear(
                     model._modules[name].set_weights_bias(
                         module.qweight.data if hasattr(module, "qweight") else weight,
                         None if module.bias is None else module.bias.data,
-                        **(
-                            {"update_g_idx": not empty_weights}
-                            if "update_g_idx" in model._modules[name].set_weights_bias.__code__.co_varnames
-                            else {}
-                        ),
                     )
                 else:
                     raise Exception("{} device Unsupported weight only quantization!".format(device))
@@ -359,7 +354,7 @@ def run_fn_for_autoround(model, dataloader):
             model(data)
 
 
-def convert_to_quantized_model(model, config, device="cpu"):
+def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
     if device == "xpu" or device == torch.device("xpu"):
         import intel_extension_for_pytorch
 
@@ -608,6 +603,14 @@ def set_nontext_module_config(model, to_quant_block_names, config):
             logger.warning("The recommended ipex version is higher than 2.3.10 for xpu device.")
 
     model.eval()
+    if not for_inference:
+        # Return the INC-format model directly (no IPEX linear replacement).
+        if config.use_layer_wise and not (model.device == device or model.device.type == device):
+            logger.warning(
+                "Skip device conversion to avoid running out of memory. "
+                "It is recommended to use the saved quantized model for inference."
+            )
+            return model
+        return model.to(device)
 
     q_model = replace_linear(model, None, None, config, device=device)
@@ -688,6 +691,60 @@ def make_contiguous(model):
             param.data = param.data.contiguous()
 
 
+def save_low_bit_for_inc(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+
+    assert hasattr(self, "quantization_config"), "Detected this model is not a low-bit model."
+
+    if os.path.isfile(save_directory):
+        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+        return
+
+    os.makedirs(save_directory, exist_ok=True)
+    # Fall back to the original transformers `save_pretrained` for the actual save,
+    # then re-bind this wrapper afterwards.
+    del self.save_pretrained
+    make_contiguous(self)
+    self.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+    self.save_pretrained = types.MethodType(save_low_bit_for_inc, self)
+    # We conveniently save all the keys of the model to have them on hand,
+    # so that when using 'low_cpumem load' it is not necessary to load the
+    # entire model to extract its keys, and we avoid relying on gc being triggered.
+    all_checkpoint_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
+    json_file_path = os.path.join(save_directory, "all_checkpoint_keys.json")
+    with open(json_file_path, "w") as json_file:
+        json.dump(all_checkpoint_keys, json_file)
+    if push_to_hub:
+        token = kwargs.pop("token", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            logger.warning(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers."
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+        commit_message = kwargs.pop("commit_message", None)
+        repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+        repo_id = self._create_repo(repo_id, **kwargs)
+        files_timestamps = self._get_files_timestamps(save_directory)
+        self._upload_modified_files(
+            save_directory,
+            repo_id,
+            files_timestamps,
+            commit_message=commit_message,
+            token=kwargs.get("token"),
+        )
+    self.quantization_config.save_pretrained(save_directory, **kwargs)
+
+
 def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
 
     assert hasattr(self, "quantization_config"), "Detected this model is not a low-bit model."
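
Note (reviewer sketch, not part of this patch): save_low_bit_for_inc relies on the fact that binding a function to the instance with types.MethodType shadows the class-level save_pretrained, and that deleting the instance attribute temporarily restores the original so it can do the real serialization. A minimal standalone Python illustration of that attribute-swap trick, using a hypothetical Saver class:

import types


class Saver:
    def save_pretrained(self, path):
        # Stand-in for the original transformers implementation.
        return f"class save_pretrained -> {path}"


def custom_save(self, path):
    # Temporarily remove the instance-level override so the class
    # implementation performs the actual save.
    del self.save_pretrained
    result = self.save_pretrained(path)
    # Re-bind the override so later calls keep going through the wrapper.
    self.save_pretrained = types.MethodType(custom_save, self)
    return "custom wrapper: " + result


obj = Saver()
obj.save_pretrained = types.MethodType(custom_save, obj)
print(obj.save_pretrained("./out"))  # custom wrapper: class save_pretrained -> ./out
print(obj.save_pretrained("./out"))  # still wrapped on the second call
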
diff --git a/test/3x/torch/quantization/weight_only/test_transformers.py b/test/3x/torch/quantization/weight_only/test_transformers.py
index 8a9dbcbecf1..1fa8c0f803d 100644
--- a/test/3x/torch/quantization/weight_only/test_transformers.py
+++ b/test/3x/torch/quantization/weight_only/test_transformers.py
@@ -277,3 +277,34 @@ def test_vlm(self):
         # model_name = "microsoft/Phi-3-vision-128k-instruct"
         # woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True, attn_implementation='eager')
         # assert isinstance(woq_model.model.layers[0].self_attn.o_proj, WeightOnlyQuantizedLinear), "quantizaion failed."
+
+    def test_save_load_for_inc_model(self):
+        model_name_or_path = self.model_name_or_path
+
+        fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # RTN with the default for_inference=True (low-bit model for inference)
+        woq_config = RtnConfig(bits=4, group_size=16)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output = woq_model(dummy_input)[0]
+
+        # RTN with for_inference=False (keep the INC-format model)
+        woq_config = RtnConfig(bits=4, group_size=16)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+            for_inference=False,
+        )
+
+        # save the INC-format model
+        output_dir = "./transformers_tmp"
+        woq_model.save_pretrained(output_dir)
+
+        # load and compare against the inference model's output
+        loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
+        loaded_output = loaded_model(dummy_input)[0]
+        assert torch.equal(woq_output, loaded_output), "loaded output should match the original quantized output. Please double check."
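
For reference, a minimal usage sketch of the new for_inference option through the neural_compressor.transformers API. The model id and output directory below are placeholders, not values taken from this patch:

# Hedged usage sketch of the for_inference flag introduced in this PR.
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

woq_config = RtnConfig(bits=4, group_size=16)

# Default path: quantize and replace linear modules for inference;
# save_pretrained stays bound to save_low_bit.
infer_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=woq_config,
)

# INC path: keep the INC-format quantized model; save_pretrained is bound to
# save_low_bit_for_inc and writes a transformers-style checkpoint plus
# all_checkpoint_keys.json and the quantization config.
inc_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=woq_config,
    for_inference=False,
)
inc_model.save_pretrained("./saved_results_inc")  # placeholder path

# The saved directory can be reloaded with from_pretrained, as exercised by
# test_save_load_for_inc_model above.
reloaded = AutoModelForCausalLM.from_pretrained("./saved_results_inc")

With the default for_inference=True the existing behavior is unchanged, so current callers of the transformers-like API are unaffected.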