src/transformers/modeling_utils.py (24 additions & 9 deletions)
@@ -1749,7 +1749,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
                 Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
                 not compiled and adapted for CPUs.
-            int8_threshold (`float`, *optional*, defaults to 6):
+            load_in_8bit_threshold (`float`, *optional*, defaults to 6):
                 Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
                 described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
                 states value that is above this threshold will be considered an outlier and the operation on those
@@ -1759,6 +1759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
                 penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
                 (small models, fine-tuning).
+            load_in_8bit_skip_modules (`List[str]`, *optional*):
+                An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
+                as Jukebox that has several heads in different places and not necessarily at the last position.
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
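As a quick usage sketch for the two new arguments documented above (the checkpoint name and the skipped module are illustrative assumptions, not part of this diff):

```python
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint; any Hub model that supports 8-bit loading would do.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",
    device_map="auto",
    load_in_8bit=True,
    load_in_8bit_threshold=6.0,             # outlier threshold described above
    load_in_8bit_skip_modules=["lm_head"],  # keep these modules in their original dtype
)
```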
@@ -1850,7 +1853,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         load_in_8bit = kwargs.pop("load_in_8bit", False)
-        int8_threshold = kwargs.pop("int8_threshold", 6.0)
+        load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
+        load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)

@@ -2151,13 +2155,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls(config, *model_args, **model_kwargs)

         if load_in_8bit:
-            from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
+            from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear

             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

-            # We never convert lm_head or any last modules for numerical stability reasons
-            modules_to_not_convert = get_key_to_not_convert(model)
-            model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
+            # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+            if load_in_8bit_skip_modules is None:
+                modules_to_not_convert = get_keys_to_not_convert(model)
+            else:
+                modules_to_not_convert = load_in_8bit_skip_modules
+            model = replace_8bit_linear(
+                model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
+            )

         if isinstance(device_map, str):
             if model._no_split_modules is None:
@@ -2188,12 +2197,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )

         if load_in_8bit:
-            # The LM head can stay on disk / CPU
+            # The LM head / tied weights or any last module can stay on disk / CPU
             device_map_without_lm_head = {
-                key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+                key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
             }
             if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
-                raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+                raise ValueError(
+                    """
+                    Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
+                    the quantized model. If you have set a value for `max_memory` you should increase that. To have
+                    an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
+                    """
+                )
             del device_map_without_lm_head

         if from_tf:
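The new error message points users at `max_memory` and `model.hf_device_map`; a hedged sketch of how that advice might be applied (checkpoint name and memory budget are assumptions):

```python
from transformers import AutoModelForCausalLM

# Give the GPU a budget large enough that no quantized module spills to CPU or disk,
# then check where each top-level module actually landed.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",
    device_map="auto",
    load_in_8bit=True,
    max_memory={0: "24GiB"},
)
print(model.hf_device_map)
```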
src/transformers/utils/bitsandbytes.py (15 additions & 6 deletions)
@@ -114,7 +114,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold, modules_to_not_convert)

-        if isinstance(module, nn.Linear) and name != modules_to_not_convert:
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             with init_empty_weights():
                 model._modules[name] = bnb.nn.Linear8bitLt(
                     module.in_features,
@@ -126,10 +126,12 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
     return model


-def get_key_to_not_convert(model):
+def get_keys_to_not_convert(model):
     r"""
     An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
-    we may want to keep the lm_head in full precision for numerical stability reasons.
+    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
+    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
+    int8.

     Parameters:
         model (`torch.nn.Module`):
@@ -139,7 +141,9 @@ def get_key_to_not_convert(model):
     # check if it contains tied weights
     tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
     tied_model.tie_weights()
-    has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+    tied_keys = list(find_tied_parameters(tied_model).values())
+    has_tied_params = len(tied_keys) > 0

     # Check if it is a base model
     is_base_model = not hasattr(model, model.base_model_prefix)
@@ -150,5 +154,10 @@ def get_key_to_not_convert(model):

     # otherwise they have an attached head
     list_modules = list(model.named_parameters())
-    last_name = list_modules[-1][0]
-    return last_name.split(".")[0]
+    list_last_module = [list_modules[-1][0]]
+
+    # add last module together with tied weights
+    intersection = set(list_last_module) - set(tied_keys)
+    list_untouched = tied_keys + list(intersection)
+
+    return [module_name.split(".")[0] for module_name in list_untouched]
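A small standalone trace of the new selection logic in `get_keys_to_not_convert`, with made-up parameter names, showing the case where the last module is also a tied weight:

```python
# Toy stand-ins for a model whose lm_head is tied to the input embeddings.
tied_keys = ["lm_head.weight"]         # what find_tied_parameters(...) might report
list_last_module = ["lm_head.weight"]  # name of the model's last parameter

# Same composition as in the function above.
intersection = set(list_last_module) - set(tied_keys)  # last module, minus anything already tied
list_untouched = tied_keys + list(intersection)

print([name.split(".")[0] for name in list_untouched])  # -> ['lm_head']
```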