
Commit 5a69228

Merge branch 'dottxt-ai:main' into multimodal-input-handling-via-chat-interface
2 parents fb3de2d + 778cd10

File tree

17 files changed: +359 −270 lines


.github/workflows/tests.yml

Lines changed: 2 additions & 21 deletions

@@ -53,7 +53,7 @@ jobs:
        run: |
          rm -f .coverage*
          uv run coverage erase
-         uv run python -m coverage run --branch --parallel-mode -m pytest -x -m 'not api_call'
+         uv run python -m coverage run --branch --source=outlines --parallel-mode -m pytest -x -m 'not api_call'
      - name: Upload coverage data
        uses: actions/upload-artifact@v4
        with:
@@ -86,37 +86,18 @@ jobs:
        with:
          name: coverage-data

-     - name: Determine base branch for comparison
-       id: base-branch
-       run: |
-         if [[ "${{ github.event_name }}" == "pull_request" ]]; then
-           # For PRs, use the remote tracking branch
-           COMPARE_BRANCH="origin/${{ github.base_ref }}"
-           echo "COMPARE_BRANCH=$COMPARE_BRANCH" >> $GITHUB_ENV
-         else
-           # For push events, compare against the parent commit
-           COMPARE_BRANCH="${{ github.event.before }}"
-           echo "COMPARE_BRANCH=$COMPARE_BRANCH" >> $GITHUB_ENV
-         fi
-         echo "Using $COMPARE_BRANCH for coverage comparison"
-
-     - name: Fetch base branch for coverage diff
-       run: |
-         git fetch --no-tags --prune origin ${COMPARE_BRANCH#origin/}
-
      - name: Combine coverage & fail if it's <100%.
        run: |
          python -m coverage combine
          python -m coverage html --skip-covered --skip-empty
          python -m coverage xml
-         diff-cover coverage.xml --compare-branch=$COMPARE_BRANCH --markdown-report=coverage.md --fail-under=100 || (cat coverage.md >> $GITHUB_STEP_SUMMARY && exit 1)
+         python -m coverage report --fail-under=100 || (python -m coverage report && exit 1)

      - name: Upload HTML report if check failed.
        uses: actions/upload-artifact@v4
        with:
          name: html-report
          path: htmlcov
-         # TODO FIXME: This is only using the last run
          overwrite: true
          if: ${{ failure() }}
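With diff-cover removed, the gate is now a flat `coverage report --fail-under=100` over the `outlines` package alone (`--source=outlines`), rather than a per-PR comparison against the base branch. A minimal sketch of why a flat 100% gate can still pass, assuming coverage.py's default exclusion of lines tagged `# pragma: no cover` (the function below is illustrative, not code from this repo):

    # illustrative only: lines tagged "# pragma: no cover" are excluded
    # from coverage.py's report, so platform-specific branches that never
    # run on CI do not count against --fail-under=100
    def pick_backend(tensor_library_name: str) -> str:
        if tensor_library_name == "torch":
            return "torch"  # exercised by the test suite
        elif tensor_library_name == "mlx":  # pragma: no cover
            return "mlx"  # Apple-silicon only; skipped on CI runners
        else:  # pragma: no cover
            raise ValueError(f"Unsupported tensor library: {tensor_library_name}")

This is the same pattern applied by the `# pragma: no cover` additions in the backend files below.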

outlines/backends/llguidance.py

Lines changed: 0 additions & 3 deletions

@@ -37,9 +37,6 @@ def __init__(
             The name of the tensor library used by the model

         """
-        if tensor_library_name not in SUPPORTED_TENSOR_LIBRARIES:
-            raise TypeError(f"Unsupported tensor library: {tensor_library_name}")
-
         self.is_first_token = True
         self.grammar = grammar
         self.llg_tokenizer = llg_tokenizer

outlines/backends/outlines_core.py

Lines changed: 11 additions & 5 deletions

@@ -70,15 +70,15 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
             self.allocate_token_bitmask = allocate_token_bitmask
             self.bias_logits = self._bias_logits_numpy

-        elif self.tensor_library_name == "mlx":
+        elif self.tensor_library_name == "mlx":  # pragma: no cover
             from outlines_core.kernels.mlx import (
                 allocate_token_bitmask
             )

             self.allocate_token_bitmask = allocate_token_bitmask
             self.bias_logits = self._bias_logits_mlx

-        else:
+        else:  # pragma: no cover
             raise ValueError(
                 f"Unsupported tensor library: {self.tensor_library_name}"
             )
@@ -179,7 +179,13 @@ def process_logits(
         else:
             for i in range(batch_size):
                 last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])  # type: ignore
-                if not self._guides[i].is_finished():
+                # This circumvents issue #227 in outlines_core.
+                # Ideally, we would be able to advance every time, as the
+                # final state would accept the eos token leading to itself.
+                if (
+                    not self._guides[i].is_finished()
+                    or self._guides[i].accepts_tokens([last_token_id])
+                ):
                     self._guides[i].advance(
                         token_id=last_token_id,
                         return_tokens=False
@@ -211,13 +217,13 @@ def __init__(self, model: SteerableModel):
             eos_token_id = tokenizer.eos_token_id
             eos_token = tokenizer.eos_token
             token_to_str = tokenizer.convert_token_to_string
-        elif isinstance(model, MLXLM):
+        elif isinstance(model, MLXLM):  # pragma: no cover
             tokenizer = model.mlx_tokenizer  # type: ignore
             vocabulary = tokenizer.get_vocab()
             eos_token_id = tokenizer.eos_token_id
             eos_token = tokenizer.eos_token
             token_to_str = lambda token: tokenizer.convert_tokens_to_string([token])  # type: ignore
-        else:
+        else:  # pragma: no cover
             raise ValueError(f"Unsupported model type: {type(model)}")

         self.eos_token_id = eos_token_id

outlines/backends/xgrammar.py

Lines changed: 3 additions & 3 deletions

@@ -39,7 +39,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
         """Setup the logits processor for a new generation."""
         if self.tensor_library_name == "torch":
             self._bias_logits = self._bias_logits_torch
-        elif self.tensor_library_name == "mlx":
+        elif self.tensor_library_name == "mlx":  # pragma: no cover
             self._bias_logits = self._bias_logits_mlx
         else:  # pragma: no cover
             raise ValueError(
@@ -101,7 +101,7 @@ def process_logits(
             self.is_first_token = False
         else:
             for i in range(batch_size):
-                if not self._matchers[i].is_terminated():
+                if not self._matchers[i].is_terminated():  # pragma: no cover
                     last_token_id = self.tensor_adapter.to_scalar(
                         input_ids[i][-1]  # type: ignore
                     )
@@ -125,7 +125,7 @@ def __init__(self, model: SteerableModel):

         if isinstance(model, Transformers):
             tokenizer = model.hf_tokenizer
-        elif isinstance(model, MLXLM):
+        elif isinstance(model, MLXLM):  # pragma: no cover
             tokenizer = model.mlx_tokenizer._tokenizer
         else:  # pragma: no cover
             raise ValueError(

outlines/caching.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ def get_cache():
         cache_dir = outlines_cache_dir
     elif xdg_cache_home:  # pragma: no cover
         cache_dir = os.path.join(xdg_cache_home, ".cache", "outlines")
-    elif home_dir != "/":
+    elif home_dir != "/":  # pragma: no cover
         cache_dir = os.path.join(home_dir, ".cache", "outlines")
     else:  # pragma: no cover
         # home_dir may be / inside a docker container without existing user

outlines/models/transformers.py

Lines changed: 18 additions & 0 deletions

@@ -244,11 +244,29 @@ def __init__(
             and isinstance(model, FlaxPreTrainedModel)
         ):
             self.tensor_library_name = "jax"
+            warnings.warn("""
+                Support for `jax` has been deprecated and will be removed in
+                version 1.4.0 of Outlines. Please use `torch` instead.
+                Transformers models using `jax` do not support structured
+                generation.
+                """,
+                DeprecationWarning,
+                stacklevel=2,
+            )
         elif (
             TFPreTrainedModel is not None
             and isinstance(model, TFPreTrainedModel)
         ):
             self.tensor_library_name = "tensorflow"
+            warnings.warn("""
+                Support for `tensorflow` has been deprecated and will be removed in
+                version 1.4.0 of Outlines. Please use `torch` instead.
+                Transformers models using `tensorflow` do not support structured
+                generation.
+                """,
+                DeprecationWarning,
+                stacklevel=2,
+            )
         else:
             self.tensor_library_name = "torch"
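A hedged sketch of observing the new warning with the standard library; `flax_model` and `hf_tokenizer` are hypothetical stand-ins, and the two-argument constructor is assumed from the class in this file:

    import warnings

    from outlines.models.transformers import Transformers

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Transformers(flax_model, hf_tokenizer)  # hypothetical objects

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)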

outlines/processors/base_logits_processor.py

Lines changed: 3 additions & 3 deletions

@@ -28,8 +28,8 @@ def __init__(self, tensor_library_name: str):
         ----------
         tensor_library_name
             The name of the library to use to manipulate tensors. Possible
-            values are "jax", "mlx", "numpy", "tensorflow" and "torch". You
-            must choose the library that your model is using.
+            values are "mlx", "numpy" and "torch". You must choose the library
+            that your model is using.
         """
         # Temporary fix as torch raises a warning that can cause an error
         # with python 3.12.
@@ -52,7 +52,7 @@ def reset(self):
         needs to be reset for a new generation.

         """
-        pass
+        pass  # pragma: no cover

     @abstractmethod
     def process_logits(
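Per the updated docstring, only "mlx", "numpy" and "torch" remain valid tensor libraries. A minimal hypothetical subclass, assuming the base class in this file is importable as `OutlinesLogitsProcessor` and that `process_logits` is the only abstract method (as the hunks above suggest):

    from outlines.processors.base_logits_processor import OutlinesLogitsProcessor

    class NoOpLogitsProcessor(OutlinesLogitsProcessor):
        """Hypothetical subclass that passes logits through unchanged."""

        def process_logits(self, input_ids, logits):
            return logits

    processor = NoOpLogitsProcessor("torch")  # one of "mlx", "numpy", "torch"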

outlines/processors/tensor_adapters/__init__.py

Lines changed: 0 additions & 6 deletions

@@ -2,25 +2,19 @@

 from typing import Union

-from .jax import JAXTensorAdapter
 from .mlx import MLXTensorAdapter
 from .numpy import NumpyTensorAdapter
-from .tensorflow import TensorFlowTensorAdapter
 from .torch import TorchTensorAdapter


 tensor_adapters = {
-    "jax": JAXTensorAdapter,
     "mlx": MLXTensorAdapter,
     "numpy": NumpyTensorAdapter,
-    "tensorflow": TensorFlowTensorAdapter,
     "torch": TorchTensorAdapter,
 }

 TensorAdapterImplementation = Union[
-    JAXTensorAdapter,
     MLXTensorAdapter,
     NumpyTensorAdapter,
-    TensorFlowTensorAdapter,
     TorchTensorAdapter,
 ]
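After this change the registry carries exactly three adapters; a quick sanity check grounded in the trimmed dict above:

    from outlines.processors.tensor_adapters import tensor_adapters

    # "jax" and "tensorflow" no longer resolve; only these three remain
    assert set(tensor_adapters) == {"mlx", "numpy", "torch"}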

outlines/processors/tensor_adapters/jax.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

outlines/processors/tensor_adapters/tensorflow.py

Lines changed: 0 additions & 50 deletions
This file was deleted.
