Skip to content

Commit faa7c5c

Browse files
lapp0 authored and rlouf committed
automatically download exl2 model in tests
fix exl bug: sometimes piece_to_id not populated, but get_piece_to_id_dict() still works; enable exl2 in generate.cfg; create OutlinesExLlamaV2Tokenizer rather than monkey patching
1 parent 80b82f1 commit faa7c5c

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

outlines/generate/cfg.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
SequenceGeneratorAdapter,
55
VisionSequenceGeneratorAdapter,
66
)
7-
from outlines.models import ExLlamaV2Model, LlamaCpp, OpenAI, TransformersVision
7+
from outlines.models import LlamaCpp, OpenAI, TransformersVision
88
from outlines.samplers import Sampler, multinomial
99

1010

@@ -41,13 +41,6 @@ def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
4141
return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
4242

4343

44-
@cfg.register(ExLlamaV2Model)
45-
def cfg_exllamav2(model, cfg_str: str, sampler: Sampler = multinomial()):
46-
raise NotImplementedError(
47-
"Not yet available, track progress in https://github.com/dottxt-ai/outlines/pull/1010"
48-
)
49-
50-
5144
@cfg.register(LlamaCpp)
5245
def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()):
5346
raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer")

outlines/models/exllamav2.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import dataclasses
22
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union
33

4+
import torch
45
from typing_extensions import Unpack
56

67
from outlines.generate.api import GenerationParameters, SamplingParameters
78

89
if TYPE_CHECKING:
9-
from exllamav2 import ExLlamaV2Tokenizer
10+
import torch.LongTensor
1011
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
1112

1213

@@ -18,13 +19,33 @@ class ExllamaV2Params(TypedDict, total=False):
1819
max_new_tokens: List[int]
1920

2021

22+
class OutlinesExLlamaV2Tokenizer:
23+
def __init__(self, tokenizer):
24+
self.exl2_tokenizer = tokenizer
25+
self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
26+
self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
27+
self.eos_token_id = self.exl2_tokenizer.eos_token_id
28+
29+
def convert_token_to_string(self, token):
30+
return token
31+
32+
def decode(self, token_ids: "torch.LongTensor") -> List[str]:
33+
decoded = self.exl2_tokenizer.decode(
34+
torch.tensor(token_ids),
35+
decode_special_tokens=False,
36+
)
37+
if isinstance(decoded, str):
38+
return [decoded]
39+
return decoded
40+
41+
2142
class ExLlamaV2Model:
2243
"""Represents a `exl2` model."""
2344

2445
def __init__(
2546
self,
2647
generator: "ExLlamaV2DynamicGenerator",
27-
tokenizer: "ExLlamaV2Tokenizer",
48+
tokenizer: "OutlinesExLlamaV2Tokenizer",
2849
max_seq_len: int,
2950
):
3051
self.generator = generator
@@ -220,14 +241,6 @@ def token_generator() -> Iterator[str]:
220241
return token_generator()
221242

222243

223-
# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b
224-
def patch_tokenizer(tokenizer):
225-
tokenizer.vocabulary = tokenizer.piece_to_id
226-
tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)
227-
tokenizer.convert_token_to_string = lambda t: t
228-
return tokenizer
229-
230-
231244
def exl2(
232245
model_path: str,
233246
draft_model_path: Optional[str] = None,
@@ -306,7 +319,6 @@ def exl2(
306319

307320
print("Loading tokenizer...")
308321
tokenizer = ExLlamaV2Tokenizer(config)
309-
tokenizer = patch_tokenizer(tokenizer)
310322
max_batch_size = 4 if paged else 1
311323

312324
draft_model = None
@@ -337,4 +349,7 @@ def exl2(
337349
paged=paged,
338350
)
339351
max_seq_len = cache.max_seq_len
340-
return ExLlamaV2Model(generator, tokenizer, max_seq_len)
352+
353+
outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
354+
outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
355+
return outlines_exl2_model

tests/generate/test_generate.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,16 @@ def model_llamacpp(tmp_path_factory):
2222

2323
@pytest.fixture(scope="session")
2424
def model_exllamav2(tmp_path_factory):
25+
from huggingface_hub import snapshot_download
26+
27+
tmp_dir = tmp_path_factory.mktemp("model_download")
28+
model_path = snapshot_download(
29+
repo_id="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2",
30+
cache_dir=tmp_dir,
31+
)
32+
2533
return models.exl2(
26-
model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
34+
model_path=model_path,
2735
cache_q4=True,
2836
paged=False,
2937
)

0 commit comments

Comments
 (0)