Skip to content

Commit 778cd10

Browse files
committed
Fix coverage issues, improve testing logits processors
1 parent 6ae9c28 commit 778cd10

File tree

8 files changed

+308
-91
lines changed

8 files changed

+308
-91
lines changed

outlines/backends/llguidance.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ def __init__(
3737
The name of the tensor library used by the model
3838
3939
"""
40-
if tensor_library_name not in SUPPORTED_TENSOR_LIBRARIES:
41-
raise TypeError(f"Unsupported tensor library: {tensor_library_name}")
42-
4340
self.is_first_token = True
4441
self.grammar = grammar
4542
self.llg_tokenizer = llg_tokenizer

outlines/backends/outlines_core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
7070
self.allocate_token_bitmask = allocate_token_bitmask
7171
self.bias_logits = self._bias_logits_numpy
7272

73-
elif self.tensor_library_name == "mlx":
73+
elif self.tensor_library_name == "mlx": # pragma: no cover
7474
from outlines_core.kernels.mlx import (
7575
allocate_token_bitmask
7676
)
7777

7878
self.allocate_token_bitmask = allocate_token_bitmask
7979
self.bias_logits = self._bias_logits_mlx
8080

81-
else:
81+
else: # pragma: no cover
8282
raise ValueError(
8383
f"Unsupported tensor library: {self.tensor_library_name}"
8484
)
@@ -217,13 +217,13 @@ def __init__(self, model: SteerableModel):
217217
eos_token_id = tokenizer.eos_token_id
218218
eos_token = tokenizer.eos_token
219219
token_to_str = tokenizer.convert_token_to_string
220-
elif isinstance(model, MLXLM):
220+
elif isinstance(model, MLXLM): # pragma: no cover
221221
tokenizer = model.mlx_tokenizer # type: ignore
222222
vocabulary = tokenizer.get_vocab()
223223
eos_token_id = tokenizer.eos_token_id
224224
eos_token = tokenizer.eos_token
225225
token_to_str = lambda token: tokenizer.convert_tokens_to_string([token]) # type: ignore
226-
else:
226+
else: # pragma: no cover
227227
raise ValueError(f"Unsupported model type: {type(model)}")
228228

229229
self.eos_token_id = eos_token_id

outlines/backends/xgrammar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
3939
"""Setup the logits processor for a new generation."""
4040
if self.tensor_library_name == "torch":
4141
self._bias_logits = self._bias_logits_torch
42-
elif self.tensor_library_name == "mlx":
42+
elif self.tensor_library_name == "mlx": # pragma: no cover
4343
self._bias_logits = self._bias_logits_mlx
4444
else: # pragma: no cover
4545
raise ValueError(
@@ -101,7 +101,7 @@ def process_logits(
101101
self.is_first_token = False
102102
else:
103103
for i in range(batch_size):
104-
if not self._matchers[i].is_terminated():
104+
if not self._matchers[i].is_terminated(): # pragma: no cover
105105
last_token_id = self.tensor_adapter.to_scalar(
106106
input_ids[i][-1] # type: ignore
107107
)
@@ -125,7 +125,7 @@ def __init__(self, model: SteerableModel):
125125

126126
if isinstance(model, Transformers):
127127
tokenizer = model.hf_tokenizer
128-
elif isinstance(model, MLXLM):
128+
elif isinstance(model, MLXLM): # pragma: no cover
129129
tokenizer = model.mlx_tokenizer._tokenizer
130130
else: # pragma: no cover
131131
raise ValueError(

outlines/caching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_cache():
6161
cache_dir = outlines_cache_dir
6262
elif xdg_cache_home: # pragma: no cover
6363
cache_dir = os.path.join(xdg_cache_home, ".cache", "outlines")
64-
elif home_dir != "/":
64+
elif home_dir != "/": # pragma: no cover
6565
cache_dir = os.path.join(home_dir, ".cache", "outlines")
6666
else: # pragma: no cover
6767
# home_dir may be / inside a docker container without existing user
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
import numpy as np
3+
4+
5+
def simulate_model_calling_processor(processor, tensor_library_name, vocabulary_size, eos_token_id, batch_size):
6+
if tensor_library_name == "torch":
7+
tensor_adapter = TorchTensorAdapter()
8+
elif tensor_library_name == "numpy":
9+
tensor_adapter = NumpyTensorAdapter()
10+
elif tensor_library_name == "mlx":
11+
tensor_adapter = MLXTensorAdapter()
12+
13+
processor.reset()
14+
i = 0
15+
input_ids = tensor_adapter.randint(0, vocabulary_size, (batch_size, 10))
16+
while True:
17+
i += 1
18+
logits = tensor_adapter.randn((batch_size, vocabulary_size))
19+
output = processor(input_ids, logits)
20+
assert output.shape == (batch_size, vocabulary_size)
21+
if all(input_ids[:, -1] == eos_token_id):
22+
break
23+
input_ids = tensor_adapter.add_token_inputs_ids(input_ids, output)
24+
print(input_ids)
25+
if i > 20:
26+
break
27+
return input_ids[:, 10:]
28+
29+
class TorchTensorAdapter():
30+
def randn(self, shape):
31+
return torch.randn(*shape)
32+
33+
def randint(self, low, high, size):
34+
return torch.randint(low, high, size)
35+
36+
def add_token_inputs_ids(self, input_ids, logits):
37+
next_token_ids = torch.argmax(logits, dim=-1)
38+
input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(-1)], dim=-1)
39+
return input_ids
40+
41+
42+
class NumpyTensorAdapter():
43+
def randn(self, shape):
44+
return np.random.randn(*shape)
45+
46+
def randint(self, low, high, size):
47+
return np.random.randint(low, high, size)
48+
49+
def add_token_inputs_ids(self, input_ids, logits):
50+
next_token_ids = np.argmax(logits, axis=-1)
51+
print("next_token_ids",next_token_ids)
52+
input_ids = np.concatenate([input_ids, next_token_ids[..., None]], axis=-1)
53+
return input_ids
54+
55+
56+
class MLXTensorAdapter():
57+
def __init__(self):
58+
import mlx
59+
self.mlx = mlx
60+
61+
def randn(self, shape):
62+
return self.mlx.random.randn(*shape)
63+
64+
def randint(self, low, high, size):
65+
return self.mlx.random.randint(low, high, size)
66+
67+
def add_token_inputs_ids(self, input_ids, logits):
68+
next_token_ids = self.mlx.argmax(logits, axis=-1)
69+
input_ids = self.mlx.concatenate([input_ids, next_token_ids[..., None]], axis=-1)
70+
return input_ids

tests/backends/test_llguidance.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import pytest
1+
import re
22

33
import llama_cpp
4-
import llguidance.hf
5-
import numpy as np
6-
import torch
4+
import llguidance
5+
import pytest
76
import transformers
87
from llguidance import LLTokenizer
98

@@ -12,9 +11,9 @@
1211
LLGuidanceBackend,
1312
LLGuidanceLogitsProcessor
1413
)
14+
from tests.backends.test_backends_utils import simulate_model_calling_processor
1515

1616
try:
17-
import mlx.core as mx
1817
import mlx_lm
1918
HAS_MLX = True
2019
except ImportError:
@@ -40,20 +39,6 @@ def model_mlxlm():
4039
*mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
4140
)
4241

43-
@pytest.fixture
44-
def llg_tokenizer():
45-
return llguidance.hf.from_tokenizer(
46-
transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
47-
)
48-
49-
@pytest.fixture
50-
def llg_grammar_spec():
51-
return (
52-
'{"grammars": [{ "json_schema": {"type": "object", "properties":'
53-
+ ' {"name": {"type": "string"}, "age": {"type": "integer"}}, "requ'
54-
+ 'ired": ["name", "age"], "additionalProperties": false} }] }'
55-
)
56-
5742
@pytest.fixture
5843
def json_schema():
5944
return (
@@ -97,42 +82,61 @@ def cfg_ebnf():
9782
"""
9883

9984

100-
def test_llguidance_processor_torch(llg_grammar_spec, llg_tokenizer):
101-
processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "torch")
102-
logits = torch.randn(2, llg_tokenizer.vocab_size)
103-
input_ids = torch.randint(0, llg_tokenizer.vocab_size, (2, 10))
104-
output = processor(input_ids, logits)
105-
assert output.shape == (2, llg_tokenizer.vocab_size)
106-
processor(input_ids, logits)
107-
85+
def test_llguidance_processor_torch(regex):
86+
model = model_transformers()
87+
tokenizer = model.tokenizer
88+
hf_tokenizer = model.hf_tokenizer
89+
llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
90+
grammar_spec = llguidance.grammar_from("regex", regex)
91+
processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "torch")
92+
for _ in range(2):
93+
input_ids = simulate_model_calling_processor(
94+
processor,
95+
"torch",
96+
len(tokenizer.get_vocab()),
97+
tokenizer.eos_token_id,
98+
2
99+
)
100+
assert re.match(regex, hf_tokenizer.decode(input_ids[0]))
101+
assert re.match(regex, hf_tokenizer.decode(input_ids[1]))
102+
103+
104+
def test_llguidance_processor_numpy(regex):
105+
model = model_llamacpp()
106+
tokenizer = model.tokenizer
107+
llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
108+
grammar_spec = llguidance.grammar_from("regex", regex)
109+
processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "numpy")
110+
for _ in range(2):
111+
input_ids = simulate_model_calling_processor(
112+
processor,
113+
"numpy",
114+
len(tokenizer.vocabulary),
115+
tokenizer.eos_token_id,
116+
2
117+
)
118+
assert re.match(regex, tokenizer.decode(input_ids[0])[0])
119+
assert re.match(regex, tokenizer.decode(input_ids[1])[0])
108120

109-
def test_llguidance_processor_numpy(llg_grammar_spec, llg_tokenizer):
110-
processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "numpy")
111-
logits = np.random.randn(2, llg_tokenizer.vocab_size)
112-
input_ids = np.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
113-
output = processor(input_ids, logits)
114-
assert output.shape == (2, llg_tokenizer.vocab_size)
115-
processor(input_ids, logits)
116121

117122

118123
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
119-
def test_llguidance_processor_mlx(llg_grammar_spec, llg_tokenizer):
120-
processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "mlx")
121-
logits = mx.random.normal((2, llg_tokenizer.vocab_size))
122-
input_ids = mx.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
123-
output = processor(input_ids, logits)
124-
assert output.shape == (2, llg_tokenizer.vocab_size)
125-
processor(input_ids, logits)
126-
127-
128-
def test_llguidance_processor_tensorflow(llg_grammar_spec, llg_tokenizer):
129-
with pytest.raises(TypeError):
130-
LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "tensorflow")
131-
132-
133-
def test_llguidance_processor_jax(llg_grammar_spec, llg_tokenizer):
134-
with pytest.raises(TypeError):
135-
LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "jax")
124+
def test_llguidance_processor_mlx(regex):
125+
model = model_mlxlm()
126+
tokenizer = model.mlx_tokenizer
127+
llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
128+
grammar_spec = llguidance.grammar_from("regex", regex)
129+
processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "mlx")
130+
for _ in range(2):
131+
input_ids = simulate_model_calling_processor(
132+
processor,
133+
"mlx",
134+
len(tokenizer.vocabulary),
135+
tokenizer.eos_token_id,
136+
2
137+
)
138+
assert re.match(regex, tokenizer.decode(input_ids[0]))
139+
assert re.match(regex, tokenizer.decode(input_ids[1]))
136140

137141

138142
models = [
@@ -155,7 +159,6 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_
155159
generator = outlines.Generator(model, backend="llguidance", processor=processor)
156160
response = generator("Hello, how are you?")
157161
assert response[0] == "{"
158-
assert "name" in response
159162

160163
# regex
161164
processor = backend.get_regex_logits_processor(regex)
@@ -184,3 +187,16 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_
184187
generator = outlines.Generator(model, backend="llguidance", processor=processor)
185188
response = generator("Hello, how are you?")
186189
assert response == "yes" or response == "no"
190+
191+
# batch + multiple generations
192+
processor = backend.get_json_schema_logits_processor(json_schema)
193+
generator = outlines.Generator(model, backend="llguidance", processor=processor)
194+
for _ in range(2):
195+
if tensor_library_name == "torch":
196+
response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200)
197+
assert len(response) == 2
198+
for r in response:
199+
assert r[0] == "{"
200+
else:
201+
response = generator("Create a character", max_tokens=20)
202+
assert response[0] == "{"

0 commit comments

Comments
 (0)