
Commit 7743cac (1 parent: 8edf196)

younesbelkada, stas00, and sgugger authored

[bnb] Small improvements on utils (#18646)

* Small replacement
  - replace `modules_to_not_convert` by `module_to_not_convert`
* refactor a bit
  - changed variables name
  - now output a list
  - change error message
* make style
* add list
* make style
* change args name
  Co-authored-by: stas00 <[email protected]>
* fix comment
* fix typo
  Co-authored-by: stas00 <[email protected]>
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: stas00 <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>

File tree: 2 files changed, +39 -15 lines
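For orientation, here is a minimal usage sketch of the arguments this commit renames and adds on `from_pretrained`. The checkpoint name and the choice of `lm_head` as a skip module are placeholders, not something prescribed by the commit:

```python
from transformers import AutoModelForCausalLM

# `load_in_8bit_threshold` replaces the former `int8_threshold` kwarg, and
# `load_in_8bit_skip_modules` lets the caller keep specific modules in their original dtype.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",                # placeholder checkpoint
    device_map="auto",
    load_in_8bit=True,
    load_in_8bit_threshold=6.0,             # outlier threshold for the int8 matmul (see docstring diff below)
    load_in_8bit_skip_modules=["lm_head"],  # optional explicit list of modules to keep un-quantized
)
```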

src/transformers/modeling_utils.py (24 additions, 9 deletions)
@@ -1751,7 +1751,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
                 Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
                 not compiled and adapted for CPUs.
-            int8_threshold (`float`, *optional*, defaults to 6):
+            load_in_8bit_threshold (`float`, *optional*, defaults to 6):
                 Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
                 described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
                 states value that is above this threshold will be considered an outlier and the operation on those
@@ -1761,6 +1761,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
                 penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
                 (small models, fine-tuning).
+            load_in_8bit_skip_modules (`List[str]`, *optional*):
+                An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
+                as Jukebox that has several heads in different places and not necessarily at the last position.
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
@@ -1852,7 +1855,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         load_in_8bit = kwargs.pop("load_in_8bit", False)
-        int8_threshold = kwargs.pop("int8_threshold", 6.0)
+        load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
+        load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
 
@@ -2156,13 +2160,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls(config, *model_args, **model_kwargs)
 
         if load_in_8bit:
-            from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
+            from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
 
             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
 
-            # We never convert lm_head or any last modules for numerical stability reasons
-            modules_to_not_convert = get_key_to_not_convert(model)
-            model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
+            # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+            if load_in_8bit_skip_modules is None:
+                modules_to_not_convert = get_keys_to_not_convert(model)
+            else:
+                modules_to_not_convert = load_in_8bit_skip_modules
+            model = replace_8bit_linear(
+                model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
+            )
 
         if isinstance(device_map, str):
             if model._no_split_modules is None:
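Stripped of the surrounding `from_pretrained` machinery, the new selection logic above reduces to the sketch below. The wrapper name `quantize_to_8bit` is hypothetical; `get_keys_to_not_convert` and `replace_8bit_linear` are the helpers touched by this commit:

```python
from transformers.utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear


def quantize_to_8bit(model, load_in_8bit_threshold=6.0, load_in_8bit_skip_modules=None):
    # Fall back to the automatic heuristic unless the caller passed an explicit skip list.
    if load_in_8bit_skip_modules is None:
        modules_to_not_convert = get_keys_to_not_convert(model)
    else:
        modules_to_not_convert = load_in_8bit_skip_modules
    # Swap every other nn.Linear for bnb.nn.Linear8bitLt.
    return replace_8bit_linear(
        model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
    )
```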
@@ -2193,12 +2202,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         if load_in_8bit:
-            # The LM head can stay on disk / CPU
+            # The LM head / tied weights or any last module can stay on disk / CPU
             device_map_without_lm_head = {
-                key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+                key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
             }
             if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
-                raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+                raise ValueError(
+                    """
+                    Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
+                    the quantized model. If you have set a value for `max_memory` you should increase that. To have
+                    an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
+                    """
+                )
             del device_map_without_lm_head
 
         if from_tf:
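The stricter device-map check introduced above can be illustrated in isolation. In the sketch below, `check_8bit_device_map` and the example placement are made up for illustration, but the filtering and the error condition mirror the diff:

```python
def check_8bit_device_map(device_map, modules_to_not_convert):
    # Modules kept in their original dtype (lm_head, tied weights, ...) may stay on CPU or disk,
    # but every module that will be quantized must be placed on a GPU.
    device_map_without_lm_head = {
        key: device for key, device in device_map.items() if key not in modules_to_not_convert
    }
    if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
        raise ValueError(
            "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM "
            "to fit the quantized model, or increase `max_memory`. Print `model.hf_device_map` to "
            "see where each module was placed."
        )


# This placement raises, because a quantized block ("transformer.h.1") is offloaded to CPU:
# check_8bit_device_map(
#     {"transformer.h.0": 0, "transformer.h.1": "cpu", "lm_head": "cpu"},
#     modules_to_not_convert=["lm_head"],
# )
```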

src/transformers/utils/bitsandbytes.py (15 additions, 6 deletions)
@@ -114,7 +114,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold, modules_to_not_convert)
 
-        if isinstance(module, nn.Linear) and name != modules_to_not_convert:
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             with init_empty_weights():
                 model._modules[name] = bnb.nn.Linear8bitLt(
                     module.in_features,
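The change above matters because `modules_to_not_convert` is now a list of names rather than a single string, so the string equality check `name != modules_to_not_convert` is replaced by a proper membership test. A small illustrative sketch follows; the helper `collect_convertible_linears` is hypothetical and only demonstrates which `nn.Linear` children the new test lets through:

```python
import torch.nn as nn


def collect_convertible_linears(model, modules_to_not_convert=("lm_head",), prefix=""):
    # Walk the module tree the same way replace_8bit_linear does and record the nn.Linear
    # modules that the `name not in modules_to_not_convert` test would allow to convert.
    names = []
    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if len(list(module.children())) > 0:
            names.extend(collect_convertible_linears(module, modules_to_not_convert, full_name))
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            names.append(full_name)
    return names
```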
@@ -126,10 +126,12 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
     return model
 
 
-def get_key_to_not_convert(model):
+def get_keys_to_not_convert(model):
     r"""
     An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
-    we may want to keep the lm_head in full precision for numerical stability reasons.
+    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
+    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
+    int8.
 
     Parameters:
         model (`torch.nn.Module`):
@@ -139,7 +141,9 @@ def get_key_to_not_convert(model):
     # check if it contains tied weights
     tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
     tied_model.tie_weights()
-    has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+    tied_keys = list(find_tied_parameters(tied_model).values())
+    has_tied_params = len(tied_keys) > 0
 
     # Check if it is a base model
     is_base_model = not hasattr(model, model.base_model_prefix)
@@ -150,5 +154,10 @@ def get_key_to_not_convert(model):
 
     # otherwise they have an attached head
     list_modules = list(model.named_parameters())
-    last_name = list_modules[-1][0]
-    return last_name.split(".")[0]
+    list_last_module = [list_modules[-1][0]]
+
+    # add last module together with tied weights
+    intersection = set(list_last_module) - set(tied_keys)
+    list_untouched = tied_keys + list(intersection)
+
+    return [module_name.split(".")[0] for module_name in list_untouched]
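To make the new return value concrete, here is a toy walkthrough of the list-building at the end of `get_keys_to_not_convert`. The parameter names are made up, and `tied_keys` stands in for the values extracted from `find_tied_parameters`:

```python
# Toy values, for illustration only.
tied_keys = ["lm_head.weight"]         # e.g. the values of find_tied_parameters(tied_model)
list_last_module = ["lm_head.weight"]  # name of the last parameter in the model

# Keep the last module only if it is not already covered by the tied keys.
intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection)

# Top-level module names to leave un-quantized, e.g. ["lm_head"]
print([module_name.split(".")[0] for module_name in list_untouched])
```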
