@@ -1751,7 +1751,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
    https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
    Also make sure that you have enough GPU RAM to store half of the model size, since the 8-bit modules are
    not compiled and adapted for CPUs.
- int8_threshold (`float`, *optional*, defaults to 6):
+ load_in_8bit_threshold (`float`, *optional*, defaults to 6):
    Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
    described in the `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
    states value that is above this threshold will be considered an outlier and the operation on those
@@ -1761,6 +1761,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
    quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
    penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
    (small models, fine-tuning).
+ load_in_8bit_skip_modules (`List[str]`, *optional*):
+     An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models
+     such as Jukebox, which has several heads in different places and not necessarily at the last position.
subfolder (`str`, *optional*, defaults to `""`):
    In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
    specify the folder name here.
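Taken together, the two docstring entries above describe the new 8-bit knobs. A minimal usage sketch, assuming any causal LM checkpoint with `nn.Linear` layers; the checkpoint name and values below are illustrative, not part of this diff:

```python
from transformers import AutoModelForCausalLM

# Hypothetical call showing the renamed/new kwargs from this diff.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",          # placeholder checkpoint
    device_map="auto",               # let accelerate place modules across available GPUs
    load_in_8bit=True,
    load_in_8bit_threshold=6.0,      # outlier threshold, default per the docstring
    load_in_8bit_skip_modules=None,  # None -> auto-detect modules to keep in full precision
)
```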
@@ -1852,7 +1855,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
- int8_threshold = kwargs.pop("int8_threshold", 6.0)
+ load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
+ load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -2156,13 +2160,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
-     from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
+     from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear

    logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

-     # We never convert lm_head or any last modules for numerical stability reasons
-     modules_to_not_convert = get_key_to_not_convert(model)
-     model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
+     # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+     if load_in_8bit_skip_modules is None:
+         modules_to_not_convert = get_keys_to_not_convert(model)
+     else:
+         modules_to_not_convert = load_in_8bit_skip_modules
+     model = replace_8bit_linear(
+         model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
+     )

if isinstance(device_map, str):
    if model._no_split_modules is None:
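For readers unfamiliar with `replace_8bit_linear`, here is a simplified sketch of the recursion it performs, assuming bitsandbytes' `Linear8bitLt` layer; the real helper in `utils/bitsandbytes.py` also handles weight transfer and edge cases:

```python
import torch.nn as nn
import bitsandbytes as bnb

def replace_8bit_linear_sketch(model, threshold=6.0, modules_to_not_convert=("lm_head",)):
    """Illustrative only: swap nn.Linear for Linear8bitLt, skipping listed module names."""
    for name, module in model.named_children():
        if list(module.children()):
            # Recurse into container modules first.
            replace_8bit_linear_sketch(module, threshold, modules_to_not_convert)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            model._modules[name] = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,  # store int8 weights
                threshold=threshold,     # outlier magnitudes above this run in fp16
            )
    return model
```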
@@ -2193,12 +2202,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)

if load_in_8bit:
-     # The LM head can stay on disk / CPU
+     # The LM head / tied weights or any last module can stay on disk / CPU
    device_map_without_lm_head = {
-         key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+         key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
    }
    if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
-         raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+         raise ValueError(
+             """
+             Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
+             the quantized model. If you have set a value for `max_memory` you should increase that. To have
+             an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
+             """
+         )
    del device_map_without_lm_head

if from_tf:
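Note the change from `!=` to `not in` in this hunk: since `get_keys_to_not_convert` now returns a list of module names rather than a single key, the equality comparison is replaced with a membership test. A toy reproduction of the new guard, with a made-up device map:

```python
# Made-up example: "lm_head" is kept in full precision, one block is offloaded to CPU.
modules_to_not_convert = ["lm_head"]
device_map = {"transformer.h.0": 0, "transformer.h.1": "cpu", "lm_head": "cpu"}

device_map_without_lm_head = {
    key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
}
# {"transformer.h.0": 0, "transformer.h.1": "cpu"} -> "cpu" is present, so this raises.
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
    raise ValueError("Some modules are dispatched on the CPU or the disk.")
```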