@@ -70,15 +70,15 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
             self.allocate_token_bitmask = allocate_token_bitmask
             self.bias_logits = self._bias_logits_numpy
 
-        elif self.tensor_library_name == "mlx":
+        elif self.tensor_library_name == "mlx":  # pragma: no cover
             from outlines_core.kernels.mlx import (
                 allocate_token_bitmask
             )
 
             self.allocate_token_bitmask = allocate_token_bitmask
             self.bias_logits = self._bias_logits_mlx
 
-        else:
+        else:  # pragma: no cover
             raise ValueError(
                 f"Unsupported tensor library: {self.tensor_library_name}"
             )
@@ -179,7 +179,13 @@ def process_logits(
         else:
             for i in range(batch_size):
                 last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])  # type: ignore
-                if not self._guides[i].is_finished():
+                # This circumvents issue #227 in outlines_core.
+                # Ideally, we would always be able to advance, as the final
+                # state would accept the eos token, leading back to itself.
+                if (
+                    not self._guides[i].is_finished()
+                    or self._guides[i].accepts_tokens([last_token_id])
+                ):
                     self._guides[i].advance(
                         token_id=last_token_id,
                         return_tokens=False
@@ -211,13 +217,13 @@ def __init__(self, model: SteerableModel):
             eos_token_id = tokenizer.eos_token_id
             eos_token = tokenizer.eos_token
             token_to_str = tokenizer.convert_token_to_string
-        elif isinstance(model, MLXLM):
+        elif isinstance(model, MLXLM):  # pragma: no cover
             tokenizer = model.mlx_tokenizer  # type: ignore
             vocabulary = tokenizer.get_vocab()
             eos_token_id = tokenizer.eos_token_id
             eos_token = tokenizer.eos_token
             token_to_str = lambda token: tokenizer.convert_tokens_to_string([token])  # type: ignore
-        else:
+        else:  # pragma: no cover
             raise ValueError(f"Unsupported model type: {type(model)}")
 
         self.eos_token_id = eos_token_id
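
For readers skimming the `process_logits` hunk above, here is a minimal, self-contained sketch of the guarded-advance pattern it introduces. The `FakeGuide` class is a hypothetical stand-in, not the real `outlines_core` guide: only the `is_finished`, `accepts_tokens`, and `advance` calls mirror the diff, and the eos behaviour is assumed purely for illustration.

```python
# Hypothetical stand-in for an outlines_core guide: only the three methods
# used in the diff above are modelled, with made-up eos behaviour.
class FakeGuide:
    def __init__(self, eos_token_id: int = 2):
        self.eos_token_id = eos_token_id
        self.finished = False

    def is_finished(self) -> bool:
        return self.finished

    def accepts_tokens(self, token_ids: list[int]) -> bool:
        # Assume a finished guide only accepts eos (final state loops on eos).
        return not self.finished or token_ids == [self.eos_token_id]

    def advance(self, token_id: int, return_tokens: bool = False):
        if token_id == self.eos_token_id:
            self.finished = True


guides = [FakeGuide(), FakeGuide()]
last_token_id = 2  # pretend the last sampled token was eos

for guide in guides:
    # Mirrors the patched condition: advance while the guide is running, or
    # when a finished guide still accepts the last token (the issue #227 case).
    if not guide.is_finished() or guide.accepts_tokens([last_token_id]):
        guide.advance(token_id=last_token_id, return_tokens=False)
```

The extra `accepts_tokens` check means a guide that has already reached its final state can still be advanced on a token it accepts (such as eos) rather than being skipped outright.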