@@ -325,8 +325,9 @@ def __init__(self, sliding_window, *args, **kwargs):
             sliding_window (`int`):
                 Effective window size: number of tokens that are kept on each update call.
         """
-        kwargs.pop("max_cache_len", None)
-        super().__init__(*args, max_cache_len=sliding_window, **kwargs)
+        max_cache_len = kwargs.pop("max_cache_len", None)
+        max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
+        super().__init__(*args, max_cache_len=max_cache_len, **kwargs)

     def update(
         self,
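The removed code simply discarded any user-supplied `max_cache_len` and always used the window size; the new code clamps the requested length to the window instead. A minimal standalone sketch of that clamping, using illustrative stand-in classes rather than the real cache layers:

```python
# Illustrative stand-ins, not the actual transformers classes.
class _BaseLayer:
    def __init__(self, max_cache_len=None):
        self.max_cache_len = max_cache_len


class _SlidingWindowLayer(_BaseLayer):
    def __init__(self, sliding_window, **kwargs):
        # A caller may request a max_cache_len larger than the window; the
        # effective capacity can never exceed the sliding window itself.
        max_cache_len = kwargs.pop("max_cache_len", None)
        max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
        super().__init__(max_cache_len=max_cache_len, **kwargs)


print(_SlidingWindowLayer(sliding_window=4096, max_cache_len=32768).max_cache_len)  # 4096
print(_SlidingWindowLayer(sliding_window=4096, max_cache_len=1024).max_cache_len)   # 1024
print(_SlidingWindowLayer(sliding_window=4096).max_cache_len)                       # 4096
```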
@@ -1277,9 +1278,7 @@ def max_batch_size(self) -> int:
     def max_cache_len(self) -> int:
         """Return the maximum cache length of the cache"""
         values = [layer.max_cache_len for layer in self.layers]
-        if len(set(values)) > 1:
-            raise ValueError(f"Max cache length is not consistent across layers: {values}")
-        return values[0]
+        return max(values)

     @property
     def is_compileable(self) -> bool:
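With sliding-window layers clamped to the window, a hybrid cache can legitimately hold layers with different capacities, so the property now reports the largest per-layer length instead of raising. A toy sketch of the idea (stand-in classes, not the actual `Cache`):

```python
# Illustrative sketch of per-layer capacities in a hybrid cache.
class _FakeLayer:
    def __init__(self, max_cache_len):
        self.max_cache_len = max_cache_len


class _FakeCache:
    def __init__(self, layers):
        self.layers = layers

    @property
    def max_cache_len(self):
        # Layers may report different lengths (e.g. sliding-window layers are
        # clamped to the window), so return the largest one instead of raising.
        return max(layer.max_cache_len for layer in self.layers)


cache = _FakeCache([_FakeLayer(4096), _FakeLayer(32768)])  # sliding + full-attention layers
print(cache.max_cache_len)  # 32768
```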
@@ -1655,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache):
     """

     def __init__(self, **kwargs) -> None:
-        Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)


 class HQQQuantizedCache(QuantizedCache):
@@ -1697,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache):

     def __init__(self, backend="HQQ", **kwargs) -> None:
         assert backend == "HQQ"
-        Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)


 class EncoderDecoderCache(Cache):
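Both quantized caches receive the same fix: the explicit initializer call is routed through `DynamicCache` rather than the `Cache` base class, presumably so that `DynamicCache`'s own setup is no longer bypassed. A toy hierarchy illustrating why that distinction matters (names here are illustrative, not the real classes):

```python
# Toy hierarchy showing why the explicit parent call matters; names are illustrative.
class Base:
    def __init__(self, **kwargs):
        self.base_ready = True


class Dynamic(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layers = []  # state that only Dynamic sets up


class Quantized(Dynamic):
    def __init__(self, **kwargs):
        # Calling Base.__init__ directly would leave `self.layers` undefined;
        # calling Dynamic.__init__ keeps the intermediate class's setup.
        Dynamic.__init__(self, **kwargs)


print(Quantized().layers)  # [] -- the dynamic-cache state exists
```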
@@ -1951,10 +1950,6 @@ def parse_layer_args_from_model_config(
     )
     # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window)
     max_cache_len = max_cache_len or config.max_position_embeddings
-    if getattr(config, "sliding_window", None) is not None:
-        sliding_window_len = min(config.sliding_window, max_cache_len)
-    else:
-        sliding_window_len = None
     # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
     head_dim = (
         config.head_dim
@@ -1981,7 +1976,7 @@ def parse_layer_args_from_model_config(
         "layer_device_map": layer_device_map,
         "head_dim": head_dim,
         "num_heads": num_heads,
-        "sliding_window": sliding_window_len,
+        "sliding_window": getattr(config, "sliding_window", None),
     }
     return {k: v for k, v in layer_args.items() if v is not None}
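With the pre-clamping removed from the config parser, the raw `config.sliding_window` (or no entry at all, thanks to the `None` filter) is forwarded to the layer, and the clamping against `max_cache_len` now happens in the sliding-window layer's `__init__` from the first hunk. A condensed, self-contained sketch of that hand-off, using hypothetical stand-in configs:

```python
from types import SimpleNamespace

# Stand-in config objects; only the fields used below are modelled.
mistral_like = SimpleNamespace(sliding_window=4096, max_position_embeddings=32768)
llama_like = SimpleNamespace(max_position_embeddings=8192)  # no sliding_window attribute


def parse_layer_args(config, max_cache_len=None):
    # Mirrors the new behaviour: forward the raw sliding_window (if any) and let
    # the layer itself clamp it; keys whose value is None are dropped.
    max_cache_len = max_cache_len or config.max_position_embeddings
    layer_args = {
        "max_cache_len": max_cache_len,
        "sliding_window": getattr(config, "sliding_window", None),
    }
    return {k: v for k, v in layer_args.items() if v is not None}


print(parse_layer_args(mistral_like))  # {'max_cache_len': 32768, 'sliding_window': 4096}
print(parse_layer_args(llama_like))    # {'max_cache_len': 8192}
```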