Skip to content

Commit cef0553

Browse files
authored
Outlines guided generation (#1539)
This WIP PR starts to add grammar support via Outlines. Currently this PR supports very simple regex grammars and does not optimize for precompiling or caching grammar FSMs. Todo: - [X] add simple outlines guidance to `NextTokenChooser` - [X] update protos for grammar - [X] update generation params API - [X] constrain simple grammar - [ ] support parsing more complex grammars into FSMs - [ ] support all grammar types that Outlines supports - [ ] explore optimizations to avoid recompiling grammars. Guided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data-raw '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6, "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+" } }' | jq ``` response ```json { "generated_text": "[email protected]" } ``` Unguided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6 } }' | jq ``` response ```json { "generated_text": " email = 'david" } ```
1 parent 4c2848b commit cef0553

File tree

31 files changed

+1660
-53
lines changed

31 files changed

+1660
-53
lines changed

benchmark/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::app::App;
88
use crate::event::Event;
99
use crossterm::ExecutableCommand;
1010
use std::io;
11-
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
11+
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
1212
use tokenizers::Tokenizer;
1313
use tokio::sync::{broadcast, mpsc};
1414
use tui::backend::CrosstermBackend;
@@ -45,6 +45,8 @@ pub async fn run(
4545
repetition_penalty: repetition_penalty.unwrap_or(1.0),
4646
frequency_penalty: frequency_penalty.unwrap_or(0.0),
4747
watermark,
48+
grammar: String::new(),
49+
grammar_type: GrammarType::None as i32,
4850
};
4951

5052
// Initialize terminal properties

clients/python/text_generation/client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Response,
1111
Request,
1212
Parameters,
13+
Grammar,
1314
)
1415
from text_generation.errors import parse_error
1516

@@ -76,6 +77,7 @@ def generate(
7677
watermark: bool = False,
7778
decoder_input_details: bool = False,
7879
top_n_tokens: Optional[int] = None,
80+
grammar: Optional[Grammar] = None,
7981
) -> Response:
8082
"""
8183
Given a prompt, generate the following text
@@ -138,6 +140,7 @@ def generate(
138140
watermark=watermark,
139141
decoder_input_details=decoder_input_details,
140142
top_n_tokens=top_n_tokens,
143+
grammar=grammar,
141144
)
142145
request = Request(inputs=prompt, stream=False, parameters=parameters)
143146

@@ -169,6 +172,7 @@ def generate_stream(
169172
typical_p: Optional[float] = None,
170173
watermark: bool = False,
171174
top_n_tokens: Optional[int] = None,
175+
grammar: Optional[Grammar] = None,
172176
) -> Iterator[StreamResponse]:
173177
"""
174178
Given a prompt, generate the following stream of tokens
@@ -227,6 +231,7 @@ def generate_stream(
227231
typical_p=typical_p,
228232
watermark=watermark,
229233
top_n_tokens=top_n_tokens,
234+
grammar=grammar,
230235
)
231236
request = Request(inputs=prompt, stream=True, parameters=parameters)
232237

@@ -326,6 +331,7 @@ async def generate(
326331
watermark: bool = False,
327332
decoder_input_details: bool = False,
328333
top_n_tokens: Optional[int] = None,
334+
grammar: Optional[Grammar] = None,
329335
) -> Response:
330336
"""
331337
Given a prompt, generate the following text asynchronously
@@ -370,6 +376,7 @@ async def generate(
370376
Returns:
371377
Response: generated response
372378
"""
379+
373380
# Validate parameters
374381
parameters = Parameters(
375382
best_of=best_of,
@@ -388,6 +395,7 @@ async def generate(
388395
typical_p=typical_p,
389396
watermark=watermark,
390397
top_n_tokens=top_n_tokens,
398+
grammar=grammar,
391399
)
392400
request = Request(inputs=prompt, stream=False, parameters=parameters)
393401

@@ -417,6 +425,7 @@ async def generate_stream(
417425
typical_p: Optional[float] = None,
418426
watermark: bool = False,
419427
top_n_tokens: Optional[int] = None,
428+
grammar: Optional[Grammar] = None,
420429
) -> AsyncIterator[StreamResponse]:
421430
"""
422431
Given a prompt, generate the following stream of tokens asynchronously
@@ -475,6 +484,7 @@ async def generate_stream(
475484
typical_p=typical_p,
476485
watermark=watermark,
477486
top_n_tokens=top_n_tokens,
487+
grammar=grammar,
478488
)
479489
request = Request(inputs=prompt, stream=True, parameters=parameters)
480490

clients/python/text_generation/types.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
from enum import Enum
22
from pydantic import BaseModel, validator
3-
from typing import Optional, List
3+
from typing import Optional, List, Union
44

55
from text_generation.errors import ValidationError
66

77

8+
# enum for grammar type
9+
class GrammarType(str, Enum):
10+
Json = "json"
11+
Regex = "regex"
12+
13+
14+
# Grammar type and value
15+
class Grammar(BaseModel):
16+
# Grammar type
17+
type: GrammarType
18+
# Grammar value
19+
value: Union[str, dict]
20+
21+
822
class Parameters(BaseModel):
923
# Activate logits sampling
1024
do_sample: bool = False
@@ -41,6 +55,8 @@ class Parameters(BaseModel):
4155
decoder_input_details: bool = False
4256
# Return the N most likely tokens at each step
4357
top_n_tokens: Optional[int] = None
58+
# grammar to use for generation
59+
grammar: Optional[Grammar] = None
4460

4561
@validator("best_of")
4662
def valid_best_of(cls, field_value, values):
@@ -109,6 +125,14 @@ def valid_top_n_tokens(cls, v):
109125
raise ValidationError("`top_n_tokens` must be strictly positive")
110126
return v
111127

128+
@validator("grammar")
129+
def valid_grammar(cls, v):
130+
if v is not None:
131+
if v.type == GrammarType.Regex and not v.value:
132+
raise ValidationError("`value` cannot be empty for `regex` grammar")
133+
if v.type == GrammarType.Json and not v.value:
134+
raise ValidationError("`value` cannot be empty for `json` grammar")
135+
return v
112136

113137
class Request(BaseModel):
114138
# Prompt
@@ -157,7 +181,7 @@ class Token(BaseModel):
157181
# Token text
158182
text: str
159183
# Logprob
160-
logprob: float
184+
logprob: Optional[float] = None
161185
# Is the token a special token
162186
# Can be used to ignore tokens when concatenating
163187
special: bool

docs/source/basic_tutorials/launcher.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ Options:
378378
379379
[env: TOKENIZER_CONFIG_PATH=]
380380
381+
```
382+
## DISABLE_GRAMMAR_SUPPORT
383+
```shell
384+
--disable-grammar-support
385+
Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar
386+
387+
[env: DISABLE_GRAMMAR_SUPPORT=]
388+
381389
```
382390
## ENV
383391
```shell

integration-tests/conftest.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
1717

1818
from text_generation import AsyncClient
19-
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
19+
from text_generation.types import (
20+
Response,
21+
Details,
22+
InputToken,
23+
Token,
24+
BestOfSequence,
25+
Grammar,
26+
)
2027

2128
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
2229
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@@ -224,6 +231,7 @@ def local_launcher(
224231
quantize: Optional[str] = None,
225232
trust_remote_code: bool = False,
226233
use_flash_attention: bool = True,
234+
disable_grammar_support: bool = False,
227235
dtype: Optional[str] = None,
228236
):
229237
port = random.randint(8000, 10_000)
@@ -247,6 +255,8 @@ def local_launcher(
247255

248256
env = os.environ
249257

258+
if disable_grammar_support:
259+
args.append("--disable-grammar-support")
250260
if num_shard is not None:
251261
args.extend(["--num-shard", str(num_shard)])
252262
if quantize is not None:
@@ -287,12 +297,15 @@ def docker_launcher(
287297
quantize: Optional[str] = None,
288298
trust_remote_code: bool = False,
289299
use_flash_attention: bool = True,
300+
disable_grammar_support: bool = False,
290301
dtype: Optional[str] = None,
291302
):
292303
port = random.randint(8000, 10_000)
293304

294305
args = ["--model-id", model_id, "--env"]
295306

307+
if disable_grammar_support:
308+
args.append("--disable-grammar-support")
296309
if num_shard is not None:
297310
args.extend(["--num-shard", str(num_shard)])
298311
if quantize is not None:
@@ -370,11 +383,22 @@ def docker_launcher(
370383
@pytest.fixture(scope="module")
371384
def generate_load():
372385
async def generate_load_inner(
373-
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
386+
client: AsyncClient,
387+
prompt: str,
388+
max_new_tokens: int,
389+
n: int,
390+
seed: Optional[int] = None,
391+
grammar: Optional[Grammar] = None,
392+
stop_sequences: Optional[List[str]] = None,
374393
) -> List[Response]:
375394
futures = [
376395
client.generate(
377-
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
396+
prompt,
397+
max_new_tokens=max_new_tokens,
398+
decoder_input_details=True,
399+
seed=seed,
400+
grammar=grammar,
401+
stop_sequences=stop_sequences,
378402
)
379403
for _ in range(n)
380404
]
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
{
2+
"details": {
3+
"best_of_sequences": null,
4+
"finish_reason": "length",
5+
"generated_tokens": 10,
6+
"prefill": [
7+
{
8+
"id": 1,
9+
"logprob": null,
10+
"text": "<s>"
11+
},
12+
{
13+
"id": 4321,
14+
"logprob": -13.90625,
15+
"text": "Test"
16+
},
17+
{
18+
"id": 2009,
19+
"logprob": -12.328125,
20+
"text": "request"
21+
}
22+
],
23+
"seed": null,
24+
"tokens": [
25+
{
26+
"id": 13,
27+
"logprob": -2.0566406,
28+
"special": false,
29+
"text": "\n"
30+
},
31+
{
32+
"id": 13,
33+
"logprob": -1.5253906,
34+
"special": false,
35+
"text": "\n"
36+
},
37+
{
38+
"id": 29902,
39+
"logprob": -2.7578125,
40+
"special": false,
41+
"text": "I"
42+
},
43+
{
44+
"id": 4966,
45+
"logprob": -1.9033203,
46+
"special": false,
47+
"text": " hope"
48+
},
49+
{
50+
"id": 445,
51+
"logprob": -0.5019531,
52+
"special": false,
53+
"text": " this"
54+
},
55+
{
56+
"id": 6911,
57+
"logprob": -0.21264648,
58+
"special": false,
59+
"text": " helps"
60+
},
61+
{
62+
"id": 29991,
63+
"logprob": -0.5991211,
64+
"special": false,
65+
"text": "!"
66+
},
67+
{
68+
"id": 2803,
69+
"logprob": -0.37475586,
70+
"special": false,
71+
"text": " Let"
72+
},
73+
{
74+
"id": 592,
75+
"logprob": -0.018463135,
76+
"special": false,
77+
"text": " me"
78+
},
79+
{
80+
"id": 1073,
81+
"logprob": -0.0008597374,
82+
"special": false,
83+
"text": " know"
84+
}
85+
],
86+
"top_tokens": null
87+
},
88+
"generated_text": "\n\nI hope this helps! Let me know"
89+
}

0 commit comments

Comments
 (0)