@@ -1,12 +1,13 @@
 import dataclasses
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union
 
+import torch
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
 
 if TYPE_CHECKING:
-    from exllamav2 import ExLlamaV2Tokenizer
+    import torch.LongTensor
     from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
 
 
@@ -18,13 +19,33 @@ class ExllamaV2Params(TypedDict, total=False):
     max_new_tokens: List[int]
 
 
+class OutlinesExLlamaV2Tokenizer:
+    def __init__(self, tokenizer):
+        self.exl2_tokenizer = tokenizer
+        self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
+        self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
+        self.eos_token_id = self.exl2_tokenizer.eos_token_id
+
+    def convert_token_to_string(self, token):
+        return token
+
+    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
+        decoded = self.exl2_tokenizer.decode(
+            torch.tensor(token_ids),
+            decode_special_tokens=False,
+        )
+        if isinstance(decoded, str):
+            return [decoded]
+        return decoded
+
+
 class ExLlamaV2Model:
     """Represents a `exl2` model."""
 
     def __init__(
         self,
         generator: "ExLlamaV2DynamicGenerator",
-        tokenizer: "ExLlamaV2Tokenizer",
+        tokenizer: "OutlinesExLlamaV2Tokenizer",
         max_seq_len: int,
     ):
         self.generator = generator
@@ -220,14 +241,6 @@ def token_generator() -> Iterator[str]:
         return token_generator()
 
 
-# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b
-def patch_tokenizer(tokenizer):
-    tokenizer.vocabulary = tokenizer.piece_to_id
-    tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)
-    tokenizer.convert_token_to_string = lambda t: t
-    return tokenizer
-
-
 def exl2(
     model_path: str,
     draft_model_path: Optional[str] = None,
@@ -306,7 +319,6 @@ def exl2(
 
     print("Loading tokenizer...")
    tokenizer = ExLlamaV2Tokenizer(config)
-    tokenizer = patch_tokenizer(tokenizer)
     max_batch_size = 4 if paged else 1
 
     draft_model = None
@@ -337,4 +349,7 @@ def exl2(
         paged=paged,
     )
     max_seq_len = cache.max_seq_len
-    return ExLlamaV2Model(generator, tokenizer, max_seq_len)
+
+    outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
+    outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
+    return outlines_exl2_model
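
Not part of the diff: a rough usage sketch of the new wrapper, assuming the file above is outlines/models/exllamav2.py (so the class is importable from outlines.models.exllamav2) and that an exl2 model directory exists at the placeholder path below.

# Sketch only; the model directory path is a placeholder, not taken from this change.
from exllamav2 import ExLlamaV2Config, ExLlamaV2Tokenizer
from outlines.models.exllamav2 import OutlinesExLlamaV2Tokenizer  # assumed module path

config = ExLlamaV2Config("/path/to/exl2/model")  # placeholder model directory
exl2_tokenizer = ExLlamaV2Tokenizer(config)

tokenizer = OutlinesExLlamaV2Tokenizer(exl2_tokenizer)

# The wrapper exposes the attributes outlines expects from a tokenizer.
assert isinstance(tokenizer.vocabulary, dict)
assert tokenizer.eos_token_id == exl2_tokenizer.eos_token_id

# decode() returns a list of strings, even when a single sequence is passed.
token_ids = exl2_tokenizer.encode("Hello, world!")
print(tokenizer.decode(token_ids))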