From 26f37514c4e07ef66e844f55aee5003d743b0ffa Mon Sep 17 00:00:00 2001
From: RobinPicard
Date: Mon, 11 Aug 2025 18:34:42 +0200
Subject: [PATCH] Alternative implementation of thinking mode

---
 outlines/backends/__init__.py | 48 +++++++++++++++++--
 outlines/backends/outlines_core.py | 24 +++++++---
 outlines/generator.py | 19 +++++++-
 outlines/models/base.py | 33 +++++++++++--
 .../processors/thinking_logits_processor.py | 42 ++++++++++++++++
 5 files changed, 152 insertions(+), 14 deletions(-)
 create mode 100644 outlines/processors/thinking_logits_processor.py

diff --git a/outlines/backends/__init__.py b/outlines/backends/__init__.py
index 680e54959..4a0906b5f 100644
--- a/outlines/backends/__init__.py
+++ b/outlines/backends/__init__.py
@@ -8,6 +8,10 @@ from outlines.backends.outlines_core import OutlinesCoreBackend
 from outlines.backends.xgrammar import XGrammarBackend
 from outlines.models import SteerableModel
 
+from outlines.processors.thinking_logits_processor import ThinkingLogitsProcessor
+from outlines.models.transformers import Transformers
+from outlines.models.llama_cpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 
 CFG_DEFAULT_BACKEND = "llguidance"
 
@@ -39,12 +43,30 @@ def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend:
         return LLGuidanceBackend(model)
     else:
         raise ValueError(f"Backend {backend_name} not supported")
-    
+
+
+def _get_end_thinking_token_id(end_thinking_tag: str, model: SteerableModel) -> int:
+    if isinstance(model, Transformers):
+        tokenizer = model.hf_tokenizer
+    elif isinstance(model, LlamaCpp):
+        tokenizer = model.tokenizer
+    elif isinstance(model, MLXLM):
+        tokenizer = model.mlx_tokenizer
+    encoded_end_thinking_tag = tokenizer.encode(end_thinking_tag)
+    if len(encoded_end_thinking_tag) != 1:
+        raise ValueError(
+            "The end_thinking_tag must correspond to a single token in "
+            + "the tokenizer vocabulary."
+        )
+    return encoded_end_thinking_tag[0]
 
 def get_json_schema_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     json_schema: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a JSON schema.
 
@@ -67,13 +89,20 @@ def get_json_schema_logits_processor(
         backend_name or JSON_SCHEMA_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_json_schema_logits_processor(json_schema)
+    backend_logits_processor = backend.get_json_schema_logits_processor(json_schema)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor
 
 
 def get_regex_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     regex: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a regex.
 
@@ -96,13 +125,20 @@ def get_regex_logits_processor(
         backend_name or REGEX_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_regex_logits_processor(regex)
+    backend_logits_processor = backend.get_regex_logits_processor(regex)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor
 
 
 def get_cfg_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     grammar: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a context-free grammar.
 
@@ -125,4 +161,8 @@ def get_cfg_logits_processor(
         backend_name or CFG_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_cfg_logits_processor(grammar)
+    backend_logits_processor = backend.get_cfg_logits_processor(grammar)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor
diff --git a/outlines/backends/outlines_core.py b/outlines/backends/outlines_core.py
index 9929bff9f..124dc208f 100644
--- a/outlines/backends/outlines_core.py
+++ b/outlines/backends/outlines_core.py
@@ -1,6 +1,6 @@
 """Backend class for Outlines Core."""
 
-from typing import Callable, Dict
+from typing import Callable, Dict, Union
 
 from outlines_core import Guide, Index, Vocabulary
 # TODO: change this once the import issue is fixed in outlines_core
@@ -90,7 +90,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
         ]
 
     def _bias_logits_mlx(  # pragma: no cover
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for MLX tensors."""
         from outlines_core.kernels.mlx import (
@@ -100,6 +100,9 @@ def _bias_logits_mlx(  # pragma: no cover
 
         biased_logits_array = []
         for i in range(batch_size):
+            if skip[i]:
+                biased_logits_array.append(self.tensor_adapter.unsqueeze(logits[i]))
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             biased_logits = apply_token_bitmask(
                 self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i]  # type: ignore
@@ -109,7 +112,7 @@ def _bias_logits_mlx(  # pragma: no cover
         return self.tensor_adapter.concatenate(biased_logits_array)
 
     def _bias_logits_torch(
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for Torch tensors."""
         from outlines_core.kernels.torch import (
@@ -118,6 +121,8 @@ def _bias_logits_torch(
         )
 
         for i in range(batch_size):
+            if skip[i]:
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             self._bitmasks[i] = self.tensor_adapter.to_device(
                 self._bitmasks[i],
@@ -135,7 +140,7 @@ def _bias_logits_torch(
         return logits
 
     def _bias_logits_numpy(
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for Numpy tensors."""
         from outlines_core.kernels.numpy import (
@@ -144,6 +149,8 @@ def _bias_logits_numpy(
         )
 
         for i in range(batch_size):
+            if skip[i]:
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             apply_token_bitmask_inplace(
                 self.tensor_adapter.unsqueeze(logits[i]),  # type: ignore
@@ -153,7 +160,7 @@ def _bias_logits_numpy(
         return logits
 
     def process_logits(
-        self, input_ids: TensorType, logits: TensorType
+        self, input_ids: TensorType, logits: TensorType, skip: Union[list[bool], None] = None
     ) -> TensorType:
         """Use the guides to bias the logits.
 
@@ -173,11 +180,16 @@ def process_logits(
         batch_size = self.tensor_adapter.shape(input_ids)[0]
         vocab_size = self.tensor_adapter.shape(logits)[1]
 
+        if skip is None:
+            skip = [False] * batch_size
+
         if self.is_first_token:
             self._setup(batch_size, vocab_size)
             self.is_first_token = False
         else:
             for i in range(batch_size):
+                if skip[i]:
+                    continue
                 last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])  # type: ignore
                 if not self._guides[i].is_finished():
                     self._guides[i].advance(
@@ -185,7 +197,7 @@ def process_logits(
                         return_tokens=False
                     )
 
-        return self.bias_logits(batch_size, logits)
+        return self.bias_logits(batch_size, logits, skip)
 
 
 class OutlinesCoreBackend(BaseBackend):
diff --git a/outlines/generator.py b/outlines/generator.py
index f2e669d8f..bf85cdf8d 100644
--- a/outlines/generator.py
+++ b/outlines/generator.py
@@ -218,6 +218,9 @@ def __init__(
         model: SteerableModel,
         output_type: Optional[Any],
         backend_name: Optional[str] = None,
+        *,
+        end_thinking_tag: Optional[str] = None,
+        thinking_max_tokens: Optional[int] = None,
     ):
         """
         Parameters
@@ -241,12 +244,16 @@ def __init__(
                 backend_name,
                 model,
                 cfg_string,
+                end_thinking_tag=end_thinking_tag,
+                thinking_max_tokens=thinking_max_tokens,
             )
         elif isinstance(term, JsonSchema):
             self.logits_processor = get_json_schema_logits_processor(
                 backend_name,
                 model,
                 term.schema,
+                end_thinking_tag=end_thinking_tag,
+                thinking_max_tokens=thinking_max_tokens,
             )
         else:
             regex_string = to_regex(term)
@@ -254,6 +261,8 @@ def __init__(
                 backend_name,
                 model,
                 regex_string,
+                end_thinking_tag=end_thinking_tag,
+                thinking_max_tokens=thinking_max_tokens,
             )
 
     @classmethod
@@ -349,6 +358,8 @@ def Generator(
     backend: Optional[str] = None,
     *,
     processor: Optional[LogitsProcessorType] = None,
+    end_thinking_tag: Optional[str] = None,
+    thinking_max_tokens: Optional[int] = None,
 ) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
     """Create a generator for the given model and output parameters.
 
@@ -389,7 +400,13 @@ def Generator(
         if processor is not None:
             return SteerableGenerator.from_processor(model, processor)  # type: ignore
         else:
-            return SteerableGenerator(model, output_type, backend)  # type: ignore
+            return SteerableGenerator(
+                model,
+                output_type,
+                backend,
+                end_thinking_tag=end_thinking_tag,
+                thinking_max_tokens=thinking_max_tokens
+            )
     else:
         if processor is not None:
             raise NotImplementedError(
diff --git a/outlines/models/base.py b/outlines/models/base.py
index 2ad0407f3..58654a95f 100644
--- a/outlines/models/base.py
+++ b/outlines/models/base.py
@@ -82,6 +82,9 @@ def __call__(
         model_input: Any,
         output_type: Optional[Any] = None,
         backend: Optional[str] = None,
+        *,
+        end_thinking_tag: Optional[str] = None,
+        thinking_max_tokens: Optional[int] = None,
         **inference_kwargs: Any
     ) -> Any:
         """Call the model.
@@ -119,13 +122,22 @@ def __call__(
         """
         from outlines.generator import Generator
 
-        return Generator(self, output_type, backend)(model_input, **inference_kwargs)
+        return Generator(
+            self,
+            output_type,
+            backend,
+            end_thinking_tag=end_thinking_tag,
+            thinking_max_tokens=thinking_max_tokens
+        )(model_input, **inference_kwargs)
 
     def batch(
         self,
         model_input: List[Any],
         output_type: Optional[Any] = None,
         backend: Optional[str] = None,
+        *,
+        end_thinking_tag: Optional[str] = None,
+        thinking_max_tokens: Optional[int] = None,
         **inference_kwargs: Any
     ) -> List[Any]:
         """Make a batch call to the model (several inputs at once).
@@ -164,7 +176,13 @@ def batch(
         """
         from outlines import Generator
 
-        generator = Generator(self, output_type, backend)
+        generator = Generator(
+            self,
+            output_type,
+            backend,
+            end_thinking_tag=end_thinking_tag,
+            thinking_max_tokens=thinking_max_tokens
+        )
         return generator.batch(model_input, **inference_kwargs)  # type: ignore
 
     def stream(
@@ -172,6 +190,9 @@ def stream(
         model_input: Any,
         output_type: Optional[Any] = None,
         backend: Optional[str] = None,
+        *,
+        end_thinking_tag: Optional[str] = None,
+        thinking_max_tokens: Optional[int] = None,
         **inference_kwargs: Any
     ) -> Iterator[Any]:
         """Stream a response from the model.
@@ -212,7 +233,13 @@ def stream(
         """
         from outlines import Generator
 
-        generator = Generator(self, output_type, backend)
+        generator = Generator(
+            self,
+            output_type,
+            backend,
+            end_thinking_tag=end_thinking_tag,
+            thinking_max_tokens=thinking_max_tokens
+        )
         return generator.stream(model_input, **inference_kwargs)  # type: ignore
 
     @abstractmethod
diff --git a/outlines/processors/thinking_logits_processor.py b/outlines/processors/thinking_logits_processor.py
new file mode 100644
index 000000000..230debb21
--- /dev/null
+++ b/outlines/processors/thinking_logits_processor.py
@@ -0,0 +1,42 @@
+from outlines.processors.base_logits_processor import OutlinesLogitsProcessor, TensorType
+
+
+class ThinkingLogitsProcessor(OutlinesLogitsProcessor):
+
+    def __init__(self, end_thinking_token_id: int, thinking_max_tokens: int | None, logits_processor: OutlinesLogitsProcessor):
+        super().__init__(logits_processor.tensor_library_name)
+        self.logits_processor = logits_processor
+        self.end_thinking_token_id = end_thinking_token_id
+        self.thinking_max_tokens = thinking_max_tokens
+        self.is_first_token = True
+
+    def reset(self) -> None:
+        self.is_first_token = True
+        self.logits_processor.reset()
+
+    def setup(self, batch_size: int) -> None:
+        self._is_thinking = [self.end_thinking_token_id is not None] * batch_size
+        self._num_tokens_generated = 0
+
+    def process_logits(self, input_ids: TensorType, logits: TensorType) -> TensorType:
+
+        batch_size = self.tensor_adapter.shape(input_ids)[0]
+
+        if self.is_first_token:
+            self.setup(batch_size)
+            self.is_first_token = False
+        else:
+            self._num_tokens_generated += 1
+            for i in range(batch_size):
+                if not self._is_thinking[i]:
+                    continue
+                latest_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])
+                if latest_token_id == self.end_thinking_token_id:
+                    self._is_thinking[i] = False
+                elif self.thinking_max_tokens is not None and self._num_tokens_generated >= self.thinking_max_tokens:
+                    logits[i][self.end_thinking_token_id] = float("inf")
+
+        if all(self._is_thinking):
+            return logits
+
+        return self.logits_processor.process_logits(input_ids, logits)
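
Example usage (an illustrative sketch, not part of the diff): with the patch applied, the new keyword arguments flow from Model.__call__ through Generator down to the logits processor. The checkpoint name below is only an assumption for the example; any steerable model whose tokenizer encodes the chosen end_thinking_tag as a single token (as required by _get_end_thinking_token_id) would work the same way.

    from pydantic import BaseModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import outlines


    class Answer(BaseModel):
        answer: int


    # Assumed checkpoint: any HF model whose tokenizer has a dedicated
    # end-of-thinking token (here "</think>") is suitable.
    model = outlines.from_transformers(
        AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B"),
        AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B"),
    )

    # While the model is thinking the logits are left untouched; once
    # "</think>" is emitted (or forced after roughly thinking_max_tokens
    # generated tokens), the JSON-schema constraint takes over.
    result = model(
        "How many inhabitants does Paris have? Think step by step.",
        Answer,
        end_thinking_tag="</think>",
        thinking_max_tokens=256,
        max_new_tokens=512,
    )
    print(result)  # a JSON string matching the Answer schema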