From ea155ed2502ade5db53d84267029c34eef3f3b4b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 16 Aug 2022 10:52:43 +0200 Subject: [PATCH 1/9] Small replacement - replace `modules_to_not_convert` by `module_to_not_convert` --- src/transformers/modeling_utils.py | 6 +++--- src/transformers/utils/bitsandbytes.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d77258c94ea0..5580548808c9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2142,8 +2142,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected 8-bit loading: activating 8-bit loading for this model") # We never convert lm_head or any last modules for numerical stability reasons - modules_to_not_convert = get_key_to_not_convert(model) - model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) + module_to_not_convert = get_key_to_not_convert(model) + model = replace_8bit_linear(model, threshold=int8_threshold, module_to_not_convert=module_to_not_convert) if isinstance(device_map, str): if model._no_split_modules is None: @@ -2176,7 +2176,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if load_in_8bit: # The LM head can stay on disk / CPU device_map_without_lm_head = { - key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert + key: device_map[key] for key in device_map.keys() if key != module_to_not_convert } if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index eca605b2edef..43f9439387a6 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -85,7 +85,7 @@ class `Int8Params` from `bitsandbytes`. module._parameters[tensor_name] = new_value -def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): +def replace_8bit_linear(model, threshold=6.0, module_to_not_convert="lm_head"): """ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. 
This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -112,9 +112,9 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): """ for name, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold, modules_to_not_convert) + replace_8bit_linear(module, threshold, module_to_not_convert) - if isinstance(module, nn.Linear) and name != modules_to_not_convert: + if isinstance(module, nn.Linear) and name != module_to_not_convert: with init_empty_weights(): model._modules[name] = bnb.nn.Linear8bitLt( module.in_features, From a7731f74f74ca126cec25ac9815c6131d44b0395 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 16 Aug 2022 22:43:19 +0000 Subject: [PATCH 2/9] refactor a bit - changed variables name - now output a list - change error message --- src/transformers/modeling_utils.py | 18 ++++++++++++------ src/transformers/utils/bitsandbytes.py | 26 ++++++++++++++++++-------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5580548808c9..f4106aa6262f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -85,7 +85,7 @@ get_balanced_memory = None if is_bitsandbytes_available(): - from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device + from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device logger = logging.get_logger(__name__) @@ -2142,8 +2142,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected 8-bit loading: activating 8-bit loading for this model") # We never convert lm_head or any last modules for numerical stability reasons - module_to_not_convert = get_key_to_not_convert(model) - model = replace_8bit_linear(model, threshold=int8_threshold, module_to_not_convert=module_to_not_convert) + modules_to_not_convert = get_keys_to_not_convert(model) + model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) if isinstance(device_map, str): if model._no_split_modules is None: @@ -2174,12 +2174,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if load_in_8bit: - # The LM head can stay on disk / CPU + # The LM head / tied weights or any last module can stay on disk / CPU device_map_without_lm_head = { - key: device_map[key] for key in device_map.keys() if key != module_to_not_convert + key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert } if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): - raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") + raise ValueError( + """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the quantized model. + If you have set a value for `max_memory` you should increase that. To have an idea of the modules that are set on the CPU or RAM + you can print model.hf_device_map. 
+ """ + ) del device_map_without_lm_head if from_tf: diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 43f9439387a6..d72649a2b5aa 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,5 +1,7 @@ from copy import deepcopy +import torch + from transformers.utils import is_accelerate_available, is_bitsandbytes_available @@ -85,7 +87,7 @@ class `Int8Params` from `bitsandbytes`. module._parameters[tensor_name] = new_value -def replace_8bit_linear(model, threshold=6.0, module_to_not_convert="lm_head"): +def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): """ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -112,9 +114,9 @@ def replace_8bit_linear(model, threshold=6.0, module_to_not_convert="lm_head"): """ for name, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold, module_to_not_convert) + replace_8bit_linear(module, threshold, modules_to_not_convert) - if isinstance(module, nn.Linear) and name != module_to_not_convert: + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: with init_empty_weights(): model._modules[name] = bnb.nn.Linear8bitLt( module.in_features, @@ -126,10 +128,11 @@ def replace_8bit_linear(model, threshold=6.0, module_to_not_convert="lm_head"): return model -def get_key_to_not_convert(model): +def get_keys_to_not_convert(model): r""" An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want to keep + the tied weights of the model. The function will return a list of the keys of the modules to not convert in int8. 
Parameters: model (`torch.nn.Module`): @@ -139,7 +142,9 @@ def get_key_to_not_convert(model): # check if it contains tied weights tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` tied_model.tie_weights() - has_tied_params = len(find_tied_parameters(tied_model)) > 0 + + tied_keys = list(find_tied_parameters(tied_model).values()) + has_tied_params = len(tied_keys) > 0 # Check if it is a base model is_base_model = not hasattr(model, model.base_model_prefix) @@ -150,5 +155,10 @@ def get_key_to_not_convert(model): # otherwise they have an attached head list_modules = list(model.named_parameters()) - last_name = list_modules[-1][0] - return last_name.split(".")[0] + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = tied_keys + list(intersection) + + return [module_name.split(".")[0] for module_name in list_untouched] From bf59f9fabc84134c50cb83633df474dbe4fd60b0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 17 Aug 2022 00:45:39 +0200 Subject: [PATCH 3/9] make style --- src/transformers/modeling_utils.py | 8 ++++---- src/transformers/utils/bitsandbytes.py | 7 +++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f4106aa6262f..ed83ffa079e8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2181,10 +2181,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError( """ - Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the quantized model. - If you have set a value for `max_memory` you should increase that. To have an idea of the modules that are set on the CPU or RAM - you can print model.hf_device_map. - """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit + the quantized model. If you have set a value for `max_memory` you should increase that. To have + an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map. + """ ) del device_map_without_lm_head diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index d72649a2b5aa..b2339efd6269 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,7 +1,5 @@ from copy import deepcopy -import torch - from transformers.utils import is_accelerate_available, is_bitsandbytes_available @@ -131,8 +129,9 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): def get_keys_to_not_convert(model): r""" An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want to keep - the tied weights of the model. The function will return a list of the keys of the modules to not convert in int8. + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. 
Parameters: model (`torch.nn.Module`): From f5dc6ad23e9b9aae6d581af051a034d341aced1f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 17 Aug 2022 07:20:56 +0000 Subject: [PATCH 4/9] add list --- src/transformers/modeling_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ed83ffa079e8..33e5a03152f6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1747,6 +1747,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). + no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarly at the last position. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. @@ -1839,6 +1842,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) int8_threshold = kwargs.pop("int8_threshold", 6.0) + no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) @@ -2142,7 +2146,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected 8-bit loading: activating 8-bit loading for this model") # We never convert lm_head or any last modules for numerical stability reasons - modules_to_not_convert = get_keys_to_not_convert(model) + if no_load_in_8bit_modules is None: + modules_to_not_convert = get_keys_to_not_convert(model) + else: + modules_to_not_convert = no_load_in_8bit_modules model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) if isinstance(device_map, str): From 42c9df2d09e7bb10e6a65ffc86b1830a93c5a656 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 17 Aug 2022 12:42:14 +0200 Subject: [PATCH 5/9] make style --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 33e5a03152f6..6b25583e9941 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1748,8 +1748,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`): - An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as - Jukebox that has several heads in different places and not necessarly at the last position. + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such + as Jukebox that has several heads in different places and not necessarly at the last position. 
subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. From 27b0ef07a0e43676c525f131bea421e9f3c9f1a0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 09:27:11 +0000 Subject: [PATCH 6/9] change args name Co-authored-by: stas00 --- src/transformers/modeling_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 10f5c0a47bcb..5381dff6ddd4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1749,7 +1749,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116). Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are not compiled and adapted for CPUs. - int8_threshold (`float`, *optional*, defaults to 6): + load_in_8bit_threshold (`float`, *optional*, defaults to 6): Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden states value that is above this threshold will be considered an outlier and the operation on those @@ -1759,7 +1759,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). - no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`): + load_in_8bit_skip_modules (`List[str]`, *optional*, defaults to `None`): An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as Jukebox that has several heads in different places and not necessarly at the last position. 
subfolder (`str`, *optional*, defaults to `""`): @@ -1853,8 +1853,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) - int8_threshold = kwargs.pop("int8_threshold", 6.0) - no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None) + load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0) + load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) @@ -2160,11 +2160,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected 8-bit loading: activating 8-bit loading for this model") # We never convert lm_head or any last modules for numerical stability reasons - if no_load_in_8bit_modules is None: + if load_in_8bit_skip_modules is None: modules_to_not_convert = get_keys_to_not_convert(model) else: - modules_to_not_convert = no_load_in_8bit_modules - model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) + modules_to_not_convert = load_in_8bit_skip_modules + model = replace_8bit_linear( + model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert + ) if isinstance(device_map, str): if model._no_split_modules is None: From 224b504fc08506e5b8cafa164c4b8bcad454b81e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 09:33:45 +0000 Subject: [PATCH 7/9] fix comment --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5381dff6ddd4..9015b2919033 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2159,7 +2159,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected 8-bit loading: activating 8-bit loading for this model") - # We never convert lm_head or any last modules for numerical stability reasons + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons if load_in_8bit_skip_modules is None: modules_to_not_convert = get_keys_to_not_convert(model) else: From 01a4c0cea9d2c2ad8fd3c85afbdb274510d91344 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 09:35:23 +0000 Subject: [PATCH 8/9] fix typo Co-authored-by: stas00 --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9015b2919033..359b00d9286f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1761,7 +1761,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P (small models, fine-tuning). load_in_8bit_skip_modules (`List[str]`, *optional*, defaults to `None`): An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such - as Jukebox that has several heads in different places and not necessarly at the last position. + as Jukebox that has several heads in different places and not necessarily at the last position. 
subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. From 23fe74a6b343691512eab9d6370d73f44ab5d53a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 12 Sep 2022 15:12:16 +0200 Subject: [PATCH 9/9] Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 359b00d9286f..d00e24bad36a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1759,7 +1759,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). - load_in_8bit_skip_modules (`List[str]`, *optional*, defaults to `None`): + load_in_8bit_skip_modules (`List[str]`, *optional*): An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as Jukebox that has several heads in different places and not necessarily at the last position. subfolder (`str`, *optional*, defaults to `""`):
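
Below is a minimal usage sketch of the arguments this patch series converges on (`load_in_8bit`, `load_in_8bit_threshold`, `load_in_8bit_skip_modules`). It assumes a CUDA GPU with `bitsandbytes` and `accelerate` installed; the checkpoint name and the generation call are illustrative and not part of the patches themselves.

```python
# Sketch only: shows how the kwargs added in this patch series would be passed
# to `from_pretrained`. The checkpoint name is an example; any causal LM works
# the same way, assuming bitsandbytes + accelerate are installed and a GPU is
# available.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-350m"  # illustrative checkpoint

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",               # let accelerate dispatch the weights
    load_in_8bit=True,               # replace nn.Linear with bnb.nn.Linear8bitLt
    load_in_8bit_threshold=6.0,      # outlier threshold from the GPT3.int8() paper
    load_in_8bit_skip_modules=None,  # None -> get_keys_to_not_convert() keeps
                                     # lm_head / tied weights in full precision
)

# Inspect where each module was placed; per the error message added in PATCH 2/9,
# only the skipped modules may legitimately end up on CPU or disk.
print(model.hf_device_map)

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```

Passing an explicit list such as `load_in_8bit_skip_modules=["lm_head"]` overrides the automatic detection, which is the intended escape hatch for models like Jukebox whose heads are not the last module.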