@@ -1,20 +1,19 @@
-import pytest
+import re
 
 import llama_cpp
-import llguidance.hf
-import numpy as np
-import torch
+import llguidance
+import pytest
 import transformers
 from llguidance import LLTokenizer
 
 import outlines
 from outlines.backends.llguidance import (
    LLGuidanceBackend,
    LLGuidanceLogitsProcessor
 )
+from tests.backends.test_backends_utils import simulate_model_calling_processor
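+# test helper that mimics a model's decoding loop: it repeatedly calls the
+# logits processor and returns the sampled token ids for a small batch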
 
 try:
-    import mlx.core as mx
     import mlx_lm
     HAS_MLX = True
 except ImportError:
@@ -40,20 +39,6 @@ def model_mlxlm():
         *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
     )
 
-@pytest.fixture
-def llg_tokenizer():
-    return llguidance.hf.from_tokenizer(
-        transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
-    )
-
-@pytest.fixture
-def llg_grammar_spec():
-    return (
-        '{"grammars": [{ "json_schema": {"type": "object", "properties":'
-        + ' {"name": {"type": "string"}, "age": {"type": "integer"}}, "requ'
-        + 'ired": ["name", "age"], "additionalProperties": false} }] }'
-    )
-
 @pytest.fixture
 def json_schema():
     return (
@@ -97,42 +82,61 @@ def cfg_ebnf():
     """
 
 
-def test_llguidance_processor_torch(llg_grammar_spec, llg_tokenizer):
-    processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "torch")
-    logits = torch.randn(2, llg_tokenizer.vocab_size)
-    input_ids = torch.randint(0, llg_tokenizer.vocab_size, (2, 10))
-    output = processor(input_ids, logits)
-    assert output.shape == (2, llg_tokenizer.vocab_size)
-    processor(input_ids, logits)
-
+def test_llguidance_processor_torch(regex):
+    model = model_transformers()
+    tokenizer = model.tokenizer
+    hf_tokenizer = model.hf_tokenizer
+    llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
+    grammar_spec = llguidance.grammar_from("regex", regex)
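+    # build an llguidance grammar spec from the raw regex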
+    processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "torch")
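+    # generate twice with the same processor to check that it resets between generations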
+    for _ in range(2):
+        input_ids = simulate_model_calling_processor(
+            processor,
+            "torch",
+            len(tokenizer.get_vocab()),
+            tokenizer.eos_token_id,
+            2
+        )
+        assert re.match(regex, hf_tokenizer.decode(input_ids[0]))
+        assert re.match(regex, hf_tokenizer.decode(input_ids[1]))
+
+
+def test_llguidance_processor_numpy(regex):
+    model = model_llamacpp()
+    tokenizer = model.tokenizer
+    llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
+    grammar_spec = llguidance.grammar_from("regex", regex)
+    processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "numpy")
+    for _ in range(2):
+        input_ids = simulate_model_calling_processor(
+            processor,
+            "numpy",
+            len(tokenizer.vocabulary),
+            tokenizer.eos_token_id,
+            2
+        )
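+        # this tokenizer's decode returns a list of strings, hence the trailing [0]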
+        assert re.match(regex, tokenizer.decode(input_ids[0])[0])
+        assert re.match(regex, tokenizer.decode(input_ids[1])[0])
 
-def test_llguidance_processor_numpy(llg_grammar_spec, llg_tokenizer):
-    processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "numpy")
-    logits = np.random.randn(2, llg_tokenizer.vocab_size)
-    input_ids = np.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
-    output = processor(input_ids, logits)
-    assert output.shape == (2, llg_tokenizer.vocab_size)
-    processor(input_ids, logits)
 
 
 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
-def test_llguidance_processor_mlx(llg_grammar_spec, llg_tokenizer):
-    processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "mlx")
-    logits = mx.random.normal((2, llg_tokenizer.vocab_size))
-    input_ids = mx.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
-    output = processor(input_ids, logits)
-    assert output.shape == (2, llg_tokenizer.vocab_size)
-    processor(input_ids, logits)
-
-
-def test_llguidance_processor_tensorflow(llg_grammar_spec, llg_tokenizer):
-    with pytest.raises(TypeError):
-        LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "tensorflow")
-
-
-def test_llguidance_processor_jax(llg_grammar_spec, llg_tokenizer):
-    with pytest.raises(TypeError):
-        LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "jax")
+def test_llguidance_processor_mlx(regex):
+    model = model_mlxlm()
+    tokenizer = model.mlx_tokenizer
+    llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
+    grammar_spec = llguidance.grammar_from("regex", regex)
+    processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "mlx")
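+    # same round trip as the torch test, through the mlx tensor path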
+    for _ in range(2):
+        input_ids = simulate_model_calling_processor(
+            processor,
+            "mlx",
+            len(tokenizer.vocabulary),
+            tokenizer.eos_token_id,
+            2
+        )
+        assert re.match(regex, tokenizer.decode(input_ids[0]))
+        assert re.match(regex, tokenizer.decode(input_ids[1]))
 
 
 models = [
@@ -155,7 +159,6 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ebnf):
     generator = outlines.Generator(model, backend="llguidance", processor=processor)
     response = generator("Hello, how are you?")
     assert response[0] == "{"
-    assert "name" in response
 
     # regex
     processor = backend.get_regex_logits_processor(regex)
@@ -184,3 +187,16 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ebnf):
     generator = outlines.Generator(model, backend="llguidance", processor=processor)
     response = generator("Hello, how are you?")
     assert response == "yes" or response == "no"
+
+    # batch + multiple generations
+    processor = backend.get_json_schema_logits_processor(json_schema)
+    generator = outlines.Generator(model, backend="llguidance", processor=processor)
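+    # batch generation is only exercised with torch; other libraries generate one sequence at a time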
+    for _ in range(2):
+        if tensor_library_name == "torch":
+            response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200)
+            assert len(response) == 2
+            for r in response:
+                assert r[0] == "{"
+        else:
+            response = generator("Create a character", max_tokens=20)
+            assert response[0] == "{"