@@ -1,6 +1,6 @@
+import inspect
 import os
 import re
-import weakref
 from functools import wraps
 from typing import Dict, Optional
 
@@ -33,115 +33,101 @@
 __all__ = ["modify_save_pretrained"]
 
 
-def modify_save_pretrained(model: PreTrainedModel):
+def modify_save_pretrained(model: PreTrainedModel) -> None:
     """
     Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
-    supports compression. The new save_pretrained function performs the following saving
-    operations:
+    also supports compression params. The modified save_pretrained function performs the
+    following operations:
 
     1. Saves the model state, potentially in a compressed format
     2. Saves the recipe, appending any current recipes to existing recipe files
     3. Copies any necessary python files from the model cache
-    """
 
-    def save_pretrained_compressed(save_pretrained_method):
-        if getattr(save_pretrained_method, "_overridden", False):
-            # `model.save_pretrained` has already been replaced, return.
-            return save_pretrained_method
-
-        # Keep a weak reference to the model class and unbound save_pretrained
-        # method so we can call the original
-        model_ref = weakref.ref(save_pretrained_method.__self__)
-        original_save_pretrained = save_pretrained_method.__func__
-        model_class = model_ref().__class__
-        del save_pretrained_method
-
-        @wraps(original_save_pretrained)
-        def save_pretrained_wrapper(
-            save_directory: str,
-            sparsity_config: Optional[SparsityCompressionConfig] = None,
-            quantization_format: Optional[str] = None,
-            save_compressed: bool = True,
-            safe_serialization: bool = True,
-            skip_sparsity_compression_stats: bool = True,
-            disable_sparse_compression: bool = False,
-            **kwargs,
-        ):
-            """
-            Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
-            saving models in a compressed format on disk. The compression format is
-            saved to the model's config file
-
-            :param save_directory: output directory to save model to
-            :param sparsity_config: optional sparsity config to compress model with,
-                if no config is provided it will be inferred from the model
-            :param quantization_format: optional compression format for quantized
-                models. If none is provided it will be inferred from the model
-            :param save_compressed: whether or not to compress the model on disk
-            :param skip_sparsity_compression_stats: whether to skip the calculation of
-                sparsity statistics (such as global sparsity and sparsity structure)
-                when saving a model in dense format
-            :param disable_sparse_compression: whether to skip sparse compression
-                during save, default is False
-            :param kwargs: additional kwargs to pass on to model.save_pretrained
-            """
-
-            # HACK: Override the dtype_byte_size function in transformers to
-            # support float8 types. Fix is posted upstream
-            # https://github.com/huggingface/transformers/pull/30488
-            transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-            # state_dict gets passed in as a kwarg for FSDP models
-            state_dict = kwargs.pop("state_dict", None)
-            if state_dict is None:
-                logger.info("Fetching state_dict - this may take some time")
-                state_dict = get_state_dict_offloaded_model(model)
-
-            logger.info("Fetching compressor")
-            compressor = get_model_compressor(
-                model=model,
-                sparsity_config=sparsity_config,
-                quantization_format=quantization_format,
-                save_compressed=save_compressed,
-                skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-                state_dict=state_dict,
-                disable_sparse_compression=disable_sparse_compression,
+    :param model: The model whose save_pretrained method will be modified
+    """
+    original = model.save_pretrained
+    # Avoid double-wrapping if already modified
+    if getattr(original, "_overridden", False):
+        return
+
+    # Create enhanced signature with compression parameters
+    orig_sig = inspect.signature(original)
+    sig_with_compression_params = _create_compression_signature(orig_sig)
+
+    @wraps(original)
+    def save_pretrained_wrapper(
+        *args,
+        sparsity_config: Optional[SparsityCompressionConfig] = None,
+        quantization_format: Optional[str] = None,
+        save_compressed: bool = True,
+        skip_sparsity_compression_stats: bool = True,
+        disable_sparse_compression: bool = False,
+        **kwargs,
+    ):
+        """
+        Wrapper around PreTrainedModel.save_pretrained() that adds compression
+        functionality. The compression format is saved to the model's config file
+
+        NOTE: If adding parameters here, also update _create_compression_signature()
+        to maintain signature consistency.
+
+        :param sparsity_config: Optional sparsity compression configuration.
+            If None and `skip_sparsity_compression_stats` is False, a sparsity
+            config will be inferred from the model.
+        :param quantization_format: Optional format string for quantization
+        :param save_compressed: Whether to save the model in compressed format
+        :param skip_sparsity_compression_stats: Whether to skip calculating
+            sparsity stats.
+        :param disable_sparse_compression: Whether to disable sparse compression
+            entirely
+        """
+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
+        # Extract save_directory from args or kwargs
+        save_directory = args[0] if args else kwargs.get("save_directory")
+        if save_directory is None:
+            raise ValueError(
+                "`save_directory` must be provided as first positional arg or kwarg"
             )
 
-            if compressor is None:
-                # model is not compressed or quantized, save as normal
-                original_save_pretrained_func = original_save_pretrained.__get__(
-                    model, model_class
-                )
-                original_save_pretrained_func(
-                    save_directory, state_dict=state_dict, **kwargs
-                )
-                return
-
-            # make sure we're on the main process when saving
-            if state_dict is not None and len(state_dict) > 0:
-                compressed_state_dict = compressor.compress(model, state_dict)
-                logger.info("Saving compressed model to disk")
-                original_save_pretrained.__get__(model, model_class)(
-                    save_directory,
-                    state_dict=compressed_state_dict,
-                    safe_serialization=safe_serialization,
-                    **kwargs,
-                )
-                compressor.update_config(save_directory)
-
-            # update existing recipe
-            update_and_save_recipe(model.name_or_path, save_directory)
-
-            # copy python files from cache dir to save_path if any
-            copy_python_files_from_model_cache(model, save_directory)
-
-        save_pretrained_wrapper._overriden = True
-        return save_pretrained_wrapper
-
-    # wrap save_pretrained if not already
-    if not getattr(model.save_pretrained, "_overriden", False):
-        model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
+        # Get state_dict or fetch it if not provided
+        state_dict = kwargs.pop("state_dict", None)
+        if state_dict is None:
+            logger.info("Fetching state_dict - this may take some time")
+            state_dict = get_state_dict_offloaded_model(model)
+
+        logger.info("Fetching compressor")
+        compressor = get_model_compressor(
+            model=model,
+            sparsity_config=sparsity_config,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+            state_dict=state_dict,
+            disable_sparse_compression=disable_sparse_compression,
+        )
+
+        if compressor is None:
+            # No compression needed
+            original(*args, state_dict=state_dict, **kwargs)
+        else:
+            # Compress and save
+            compressed_state_dict = compressor.compress(model, state_dict)
+            logger.info("Saving compressed model to disk")
+            original(*args, state_dict=compressed_state_dict, **kwargs)
+            compressor.update_config(save_directory)
+
+        # These operations happen regardless of compression
+        update_and_save_recipe(model.name_or_path, save_directory)
+        copy_python_files_from_model_cache(model, save_directory)
+
+    # Apply compression signature
+    save_pretrained_wrapper.__signature__ = sig_with_compression_params
+    save_pretrained_wrapper._overridden = True
+    model.save_pretrained = save_pretrained_wrapper
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
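For context when reviewing the hunk above, here is a minimal usage sketch of the rewritten wrapper. The model stub, output directory, and import path are assumptions (the diff does not show the file path); the wrapper and its keyword-only compression parameters are the ones added above:

```python
# Hypothetical usage sketch: model stub, output dir, and module path are
# assumptions; the wrapper itself is the one defined in the diff above.
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

modify_save_pretrained(model)  # installs the wrapper in place of save_pretrained
modify_save_pretrained(model)  # no-op: the `_overridden` guard returns early

# Compression kwargs are now accepted alongside the usual arguments;
# anything else (e.g. safe_serialization) flows through **kwargs.
model.save_pretrained(
    "./opt-125m-compressed",  # save_directory as the first positional arg
    save_compressed=True,
    skip_sparsity_compression_stats=True,
    safe_serialization=True,
)
```

Assigning `__signature__` explicitly matters here: `functools.wraps` sets `__wrapped__`, and `inspect.signature` would otherwise follow it back to the original method and never report the compression parameters.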
@@ -306,3 +292,59 @@ def update_and_save_recipe(model_stub: str, save_directory: str):
     # save recipe
     recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
     recipe.yaml(recipe_path)
+
+
+def _create_compression_signature(orig_sig: inspect.Signature) -> inspect.Signature:
+    """
+    Creates an enhanced signature with compression parameters.
+
+    :param orig_sig: Original function signature
+    :return: Enhanced signature with compression parameters
+    """
+    # Define compression parameters
+    compression_params = [
+        inspect.Parameter(
+            name="sparsity_config",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=None,
+            annotation=Optional[SparsityCompressionConfig],
+        ),
+        inspect.Parameter(
+            name="quantization_format",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=None,
+            annotation=Optional[str],
+        ),
+        inspect.Parameter(
+            name="save_compressed",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=True,
+            annotation=bool,
+        ),
+        inspect.Parameter(
+            name="skip_sparsity_compression_stats",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=True,
+            annotation=bool,
+        ),
+        inspect.Parameter(
+            name="disable_sparse_compression",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=False,
+            annotation=bool,
+        ),
+    ]
+
+    # Only add parameters that don't exist in the original signature
+    existing_params = orig_sig.parameters.keys()
+    new_params = []
+
+    for param in orig_sig.parameters.values():
+        if param.kind == inspect.Parameter.VAR_KEYWORD:
+            # Add compression params before **kwargs
+            new_params.extend(
+                [p for p in compression_params if p.name not in existing_params]
+            )
+        new_params.append(param)
+
+    return orig_sig.replace(parameters=new_params)
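To make the signature splice concrete, here is a self-contained sketch of the same `inspect` technique on a stand-in function (`demo` and its single extra parameter are hypothetical; `_create_compression_signature` above applies the identical loop with the five compression parameters):

```python
# Self-contained sketch of the signature-merging idea: keyword-only
# parameters are spliced in just before **kwargs. `demo` is hypothetical.
import inspect


def demo(save_directory: str, safe_serialization: bool = True, **kwargs):
    """Stand-in for PreTrainedModel.save_pretrained (hypothetical)."""


extra = [
    inspect.Parameter(
        "save_compressed",
        kind=inspect.Parameter.KEYWORD_ONLY,
        default=True,
        annotation=bool,
    )
]

sig = inspect.signature(demo)
params = []
for param in sig.parameters.values():
    if param.kind == inspect.Parameter.VAR_KEYWORD:
        # splice the new keyword-only params in just before **kwargs
        params.extend(p for p in extra if p.name not in sig.parameters)
    params.append(param)

print(sig.replace(parameters=params))
# -> (save_directory: str, safe_serialization: bool = True, *,
#     save_compressed: bool = True, **kwargs)
```

Keyword-only parameters must precede `**kwargs` in a valid signature, which is why the splice happens at the `VAR_KEYWORD` position rather than by appending at the end.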