diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2f305ff8dd09..930e97fc324b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1751,7 +1751,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116). Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are not compiled and adapted for CPUs. - int8_threshold (`float`, *optional*, defaults to 6): + load_in_8bit_threshold (`float`, *optional*, defaults to 6): Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden states value that is above this threshold will be considered an outlier and the operation on those @@ -1761,6 +1761,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). + load_in_8bit_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such + as Jukebox that has several heads in different places and not necessarily at the last position. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. @@ -1852,7 +1855,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) - int8_threshold = kwargs.pop("int8_threshold", 6.0) + load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0) + load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) @@ -2156,13 +2160,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls(config, *model_args, **model_kwargs) if load_in_8bit: - from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear + from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear logger.info("Detected 8-bit loading: activating 8-bit loading for this model") - # We never convert lm_head or any last modules for numerical stability reasons - modules_to_not_convert = get_key_to_not_convert(model) - model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if load_in_8bit_skip_modules is None: + modules_to_not_convert = get_keys_to_not_convert(model) + else: + modules_to_not_convert = load_in_8bit_skip_modules + model = replace_8bit_linear( + model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert + ) if isinstance(device_map, str): if model._no_split_modules is None: @@ -2193,12 +2202,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if load_in_8bit: - # The LM head can stay on disk / CPU + # The LM head / tied weights or any last module can stay on disk / CPU device_map_without_lm_head = { - key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert + key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert } if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): - raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") + raise ValueError( + """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit + the quantized model. If you have set a value for `max_memory` you should increase that. To have + an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map. + """ + ) del device_map_without_lm_head if from_tf: diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index eca605b2edef..b2339efd6269 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -114,7 +114,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): if len(list(module.children())) > 0: replace_8bit_linear(module, threshold, modules_to_not_convert) - if isinstance(module, nn.Linear) and name != modules_to_not_convert: + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: with init_empty_weights(): model._modules[name] = bnb.nn.Linear8bitLt( module.in_features, @@ -126,10 +126,12 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): return model -def get_key_to_not_convert(model): +def get_keys_to_not_convert(model): r""" An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. Parameters: model (`torch.nn.Module`): @@ -139,7 +141,9 @@ def get_key_to_not_convert(model): # check if it contains tied weights tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` tied_model.tie_weights() - has_tied_params = len(find_tied_parameters(tied_model)) > 0 + + tied_keys = list(find_tied_parameters(tied_model).values()) + has_tied_params = len(tied_keys) > 0 # Check if it is a base model is_base_model = not hasattr(model, model.base_model_prefix) @@ -150,5 +154,10 @@ def get_key_to_not_convert(model): # otherwise they have an attached head list_modules = list(model.named_parameters()) - last_name = list_modules[-1][0] - return last_name.split(".")[0] + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = tied_keys + list(intersection) + + return [module_name.split(".")[0] for module_name in list_untouched]