Commit c2bc39a

[TRTLLM-1302][feat] Topk logprobs for TRT backend and top1 logprob for PyT backend (#6097)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent ef676fc commit c2bc39a

11 files changed: +222 -125 lines changed
tensorrt_llm/executor/postproc_worker.py
Lines changed: 0 additions & 1 deletion

@@ -6,7 +6,6 @@
                     Optional, Union)
 
 import zmq
-import zmq.asyncio
 
 from .._utils import nvtx_range_debug
 from ..bindings import executor as tllm

tensorrt_llm/executor/result.py
Lines changed: 10 additions & 7 deletions

@@ -91,7 +91,7 @@ class CompletionOutput:
         text (str): The generated output text. Defaults to "".
         token_ids (List[int], optional): The token ids of the generated output text. Defaults to [].
         cumulative_logprob (float, optional): The cumulative log probability of the generated output text. Defaults to None.
-        logprobs (TokenLogprobs, optional): The log probabilities of the top probability words at each position if the logprobs are requested. Defaults to None.
+        logprobs (TokenLogprobs | List[float], optional): The log probabilities of the top probability words at each position if the logprobs are requested. Defaults to None.
         prompt_logprobs (TokenLogprobs, optional): The log probabilities per prompt token. Defaults to None.
         finish_reason (Literal['stop', 'length', 'timeout', 'cancelled'], optional): The reason why the sequence is finished. Defaults to None.
         stop_reason (int, str, optional): The stop string or token id that caused the completion to stop, None if the completion finished for some other reason. Defaults to None.
@@ -102,14 +102,15 @@ class CompletionOutput:
     Attributes:
         length (int): The number of generated tokens.
         token_ids_diff (List[int]): Newly generated token ids.
-        logprobs_diff (List[float]): Logprobs of newly generated tokens.
+        logprobs_diff (TokenLogprobs | List[float]): Logprobs of newly generated tokens.
         text_diff (str): Newly generated tokens.
     """
     index: int
     text: str = ""
     token_ids: Optional[List[int]] = field(default_factory=list)
     cumulative_logprob: Optional[float] = None
-    logprobs: Optional[TokenLogprobs] = field(default_factory=list)
+    logprobs: Optional[TokenLogprobs
+                       | List[float]] = field(default_factory=list)
     prompt_logprobs: Optional[TokenLogprobs] = field(default_factory=list)
     finish_reason: Optional[Literal['stop', 'length', 'timeout',
                                     'cancelled']] = None
@@ -141,7 +142,7 @@ def token_ids_diff(self) -> List[int]:
         return self.token_ids[self._last_token_ids_len:]
 
     @property
-    def logprobs_diff(self) -> List[float]:
+    def logprobs_diff(self) -> TokenLogprobs | List[float]:
         return self.logprobs[self._last_logprobs_len:]
 
 
@@ -244,10 +245,12 @@ def _handle_sequence(self,
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]
 
         if logprobs_result:
+            # update logprobs from ResponseWrapper (TRT top logprobs WAR)
+            output._last_logprobs_len = len(output.logprobs)
             output.prompt_logprobs = logprobs_result.prompt
-            output.logprobs = logprobs_result.generation
-
-        if response_tensors.log_probs is not None:
+            output.logprobs += logprobs_result.generation
+        elif response_tensors.log_probs is not None:
+            # handle logprobs directly from response tensors
             output._last_logprobs_len = len(output.logprobs)
             output.logprobs = response_tensors.log_probs[src_idx]
             # overcome some WAR in the cpp executor
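
For context on how the streaming deltas stay consistent: `_handle_sequence` now records `_last_logprobs_len` before extending `output.logprobs`, so `logprobs_diff` only exposes the entries added by the current chunk, whether each entry is a plain float or a top-k `dict[int, Logprob]`. A minimal self-contained sketch of that bookkeeping (stand-in classes, not the real `CompletionOutput`/`Logprob`; field names are assumptions):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union


@dataclass
class FakeLogprob:
    # stand-in for tensorrt_llm.executor.result.Logprob (fields assumed)
    logprob: float
    rank: int


# each generation step is either a plain float (legacy single-logprob path)
# or a dict mapping token id -> FakeLogprob (top-k path)
StepLogprobs = Union[float, Dict[int, FakeLogprob]]


@dataclass
class FakeOutput:
    logprobs: List[StepLogprobs] = field(default_factory=list)
    _last_logprobs_len: int = 0

    @property
    def logprobs_diff(self) -> List[StepLogprobs]:
        # entries appended since the previous streamed chunk
        return self.logprobs[self._last_logprobs_len:]


out = FakeOutput()
# mimic the handler: remember the old length, then extend with the
# per-step top-k maps coming from the wrapped response
out._last_logprobs_len = len(out.logprobs)
out.logprobs += [{11: FakeLogprob(-0.1, 1), 42: FakeLogprob(-2.3, 2)},
                 {7: FakeLogprob(-0.5, 1), 13: FakeLogprob(-1.9, 2)}]
assert len(out.logprobs_diff) == 2  # both steps are new in this chunk
```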

tensorrt_llm/serve/openai_protocol.py
Lines changed: 24 additions & 9 deletions

@@ -498,7 +498,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[int] = None
+    logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
     max_completion_tokens: Optional[int] = Field(default=None,
                                                  validation_alias='max_tokens')
@@ -602,8 +602,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
-
+    def to_sampling_params(self,
+                           vocab_size: int = 32000,
+                           gather_generation_logits: bool = False,
+                           backend: Optional[str] = None) -> SamplingParams:
         sampling_params = SamplingParams(
             frequency_penalty=self.frequency_penalty,
             max_tokens=self.max_completion_tokens,
@@ -639,10 +641,20 @@ def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
 
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
-
-            # TODO: migrate to use logprobs and prompt_logprobs
-            _return_log_probs=bool(self.logprobs),
         )
+        if self.logprobs:
+            logprobs = 1 if not self.top_logprobs else self.top_logprobs
+            if backend == "pytorch":
+                sampling_params.logprobs = logprobs
+            else:
+                if gather_generation_logits:
+                    sampling_params.logprobs = logprobs
+                elif self.top_logprobs:
+                    raise ValueError(
+                        "`gather_generation_logits` must be `True` to use `top_logprobs`"
+                    )
+                else:
+                    sampling_params._return_log_probs = True
         return sampling_params
 
     @model_validator(mode='before')
@@ -667,9 +679,12 @@ def check_tool_choice(cls, data):
     @model_validator(mode="before")
     @classmethod
     def check_logprobs(cls, data):
-        top_logprobs = data.get("top_logprobs")
-        if top_logprobs is not None and top_logprobs > 0:
-            raise ValueError("top_logprobs is not supported")
+        if (top_logprobs := data.get("top_logprobs")) is not None:
+            if top_logprobs < 0:
+                raise ValueError("top_logprobs must be positive or zero")
+            if not data.get("logprobs"):
+                raise ValueError(
+                    "logprobs must be true when using top_logprobs")
        return data
 
     @model_validator(mode="before")
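
The new branch in `to_sampling_params` is essentially a small decision table: the PyTorch backend takes the requested top-k directly, the TRT backend takes it only when generation logits are gathered, and otherwise the request falls back to the legacy single-logprob flag (or is rejected if `top_logprobs` was asked for). A standalone sketch of that table, returning a plain dict instead of a real `SamplingParams` (the function name and dict keys are illustrative, not part of the API):

```python
from typing import Optional


def resolve_logprobs_params(logprobs: bool,
                            top_logprobs: int,
                            backend: Optional[str],
                            gather_generation_logits: bool) -> dict:
    """Mirror of the branch added to to_sampling_params (illustrative only)."""
    params = {"logprobs": None, "_return_log_probs": False}
    if not logprobs:
        return params
    k = top_logprobs if top_logprobs else 1  # default to top-1
    if backend == "pytorch":
        params["logprobs"] = k
    elif gather_generation_logits:
        params["logprobs"] = k
    elif top_logprobs:
        raise ValueError(
            "`gather_generation_logits` must be `True` to use `top_logprobs`")
    else:
        params["_return_log_probs"] = True  # legacy TRT single-logprob flag
    return params


print(resolve_logprobs_params(True, 2, "pytorch", False))  # top-2 via PyT sampler
print(resolve_logprobs_params(True, 0, None, False))       # legacy flag path on TRT
print(resolve_logprobs_params(True, 2, None, True))        # TRT top-2 via gathered logits
```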

tensorrt_llm/serve/openai_server.py
Lines changed: 3 additions & 1 deletion

@@ -424,7 +424,9 @@ async def create_chat_response(
         # Pass the tokenizer vocabulary size so ``logit_bias`` can be
         # expanded into an embedding bias tensor in the sampler.
         sampling_params = request.to_sampling_params(
-            vocab_size=self.tokenizer.tokenizer.vocab_size)
+            vocab_size=self.tokenizer.tokenizer.vocab_size,
+            gather_generation_logits=self.llm.args.gather_generation_logits,
+            backend=self.llm.args.backend)
         # TODO: better way to enable metrics
         if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
             sampling_params.return_perf_metrics = True

tensorrt_llm/serve/postprocess_handlers.py
Lines changed: 27 additions & 9 deletions

@@ -5,6 +5,7 @@
 from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
                         GenerationResultBase)
 from ..executor.postproc_worker import PostprocArgs
+from ..executor.result import Logprob, TokenLogprobs
 from ..llmapi.reasoning_parser import (BaseReasoningParser,
                                        ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
@@ -39,6 +40,7 @@ class ChatPostprocArgs(PostprocArgs):
     tool_choice: Optional[Union[Literal["none"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
     return_logprobs: bool = False
+    top_logprobs: bool = False
     stream_options: Optional[StreamOptions] = None
     last_message_content: Optional[str] = None
     reasoning_parser: Optional[str] = None
@@ -56,23 +58,38 @@ def from_request(cls, request: ChatCompletionRequest):
             tools=request.tools,
             tool_choice=request.tool_choice,
             stream_options=request.stream_options,
-            return_logprobs=request.logprobs,
+            return_logprobs=bool(request.logprobs),
+            top_logprobs=bool(request.top_logprobs),
         )
 
 
 def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
-                    logprobs: List[float]) -> ChatCompletionLogProbs:
+                    logprobs: List[float] | TokenLogprobs,
+                    top_logprobs: bool) -> ChatCompletionLogProbs:
     assert len(token_ids) == len(logprobs), \
         "token_ids and logprobs have different lengths"
     content: List[ChatCompletionLogProbsContent] = []
     for token_id, logprob in zip(token_ids, logprobs):
+        logprob: float | dict[int, Logprob]
         token = tokenizer.decode(token_id)
-        # returning multiple logprobs is not supported
-        first_logprob = ChatCompletionLogProbsContent(
+        chat_logprob = ChatCompletionLogProbsContent(
             token=token,
-            logprob=max(logprob, -9999.0),
-            bytes=list(token.encode("utf-8", errors="replace")))
-        content.append(first_logprob)
+            bytes=list(token.encode("utf-8", errors="replace")),
+        )
+        if isinstance(logprob, dict):
+            if token_id in logprob:
+                chat_logprob.logprob = max(logprob[token_id].logprob, -9999.0)
+            if top_logprobs:
+                chat_logprob.top_logprobs = [
+                    ChatCompletionLogProbsContent(
+                        token=(tk := tokenizer.decode(tid)),
+                        logprob=max(logprob.logprob, -9999.0),
+                        bytes=list(tk.encode("utf-8", errors="replace")))
+                    for tid, logprob in logprob.items()
+                ]
+        else:
+            chat_logprob.logprob = max(logprob, -9999.0)
+        content.append(chat_logprob)
     chat_logprobs = ChatCompletionLogProbs(content=content)
     return chat_logprobs
 
@@ -178,7 +195,7 @@ def yield_first_chat(num_tokens: int,
                 logprobs = output.logprobs_diff
                 token_ids = output.token_ids_diff
                 choice.logprobs = create_logprobs(token_ids, args.tokenizer,
-                                                  logprobs)
+                                                  logprobs, args.top_logprobs)
             if output.finish_reason is not None:
                 choice.finish_reason = output.finish_reason
                 choice.stop_reason = output.stop_reason
@@ -247,7 +264,8 @@ def chat_response_post_processor(
 
         if args.return_logprobs:
            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
-                                              output.logprobs)
+                                              output.logprobs,
+                                              args.top_logprobs)
         choices.append(choice)
 
     if args.echo and args.last_message_content:
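
The reworked `create_logprobs` accepts either the legacy per-token float list or a per-token `dict[int, Logprob]` map and, when `top_logprobs` is requested, expands each map into an OpenAI-style `top_logprobs` list. A self-contained sketch of that conversion with stand-in types (not the real protocol or tokenizer classes; names prefixed `Fake` are assumptions):

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Union


@dataclass
class FakeLogprob:            # stand-in for tensorrt_llm.executor.result.Logprob
    logprob: float
    rank: int


@dataclass
class FakeLogProbsContent:    # stand-in for ChatCompletionLogProbsContent
    token: str
    logprob: float = -9999.0
    top_logprobs: List["FakeLogProbsContent"] = field(default_factory=list)


def to_chat_logprobs(token_ids: List[int],
                     decode: Callable[[int], str],
                     logprobs: List[Union[float, Dict[int, FakeLogprob]]],
                     top_logprobs: bool) -> List[FakeLogProbsContent]:
    content = []
    for token_id, lp in zip(token_ids, logprobs):
        item = FakeLogProbsContent(token=decode(token_id))
        if isinstance(lp, dict):
            # top-k path: the chosen token's logprob plus, optionally,
            # the full set of top-k alternatives for this position
            if token_id in lp:
                item.logprob = max(lp[token_id].logprob, -9999.0)
            if top_logprobs:
                item.top_logprobs = [
                    FakeLogProbsContent(token=decode(tid),
                                        logprob=max(entry.logprob, -9999.0))
                    for tid, entry in lp.items()
                ]
        else:
            # legacy path: one float per generated token
            item.logprob = max(lp, -9999.0)
        content.append(item)
    return content


steps = [{5: FakeLogprob(-0.2, 1), 9: FakeLogprob(-1.7, 2)}, -0.4]
print(to_chat_logprobs([5, 9], str, steps, top_logprobs=True))
```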

tests/integration/defs/test_e2e.py
Lines changed: 14 additions & 0 deletions

@@ -1548,6 +1548,20 @@ def test_trtllm_serve_lora_example(llm_root, llm_venv):
                      str(test_root / "_test_trtllm_serve_lora.py")])
 
 
[email protected]("backend", ["pytorch", "trt"])
+def test_trtllm_serve_top_logprobs(llm_root, llm_venv, backend: str):
+    example_root = Path(os.path.join(llm_root, "examples", "serve"))
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pip", "install", "-r",
+        os.path.join(example_root, "requirements.txt")
+    ])
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_trtllm_serve_top_logprobs.py"), "-k", backend
+    ])
+
+
 @pytest.mark.parametrize("backend", ["pytorch", "trt"])
 def test_openai_misc_example(llm_root, llm_venv, backend: str):
     test_root = unittest_path() / "llmapi" / "apps"

tests/integration/test_lists/test-db/l0_a10.yml
Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,7 @@ l0_a10:
   - test_e2e.py::test_openai_lora
   - test_e2e.py::test_trtllm_serve_multimodal_example
   - test_e2e.py::test_trtllm_serve_lora_example
+  - test_e2e.py::test_trtllm_serve_top_logprobs[pytorch]
   - test_e2e.py::test_openai_misc_example[pytorch]
   - test_e2e.py::test_openai_reasoning[pytorch]
   - test_e2e.py::test_openai_completions_example[pytorch]
@@ -106,6 +107,7 @@ l0_a10:
   - llmapi/test_llm_examples.py::test_llmapi_server_example
   - llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector[Qwen2-0.5B]
   - test_e2e.py::test_trtllm_serve_example
+  - test_e2e.py::test_trtllm_serve_top_logprobs[trt]
   - test_e2e.py::test_openai_misc_example[trt]
   - test_e2e.py::test_openai_completions_example[trt]
   - test_e2e.py::test_openai_chat_example[trt]

tests/unittest/api_stability/references/completion_output.yaml
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ properties:
     annotation: int
     default: inspect._empty
   logprobs_diff:
-    annotation: List[float]
+    annotation: list[dict[int, tensorrt_llm.executor.result.Logprob]] | List[float]
     default: inspect._empty
   text_diff:
     annotation: str

tests/unittest/api_stability/references_committed/completion_output.yaml
Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ methods:
       annotation: Optional[torch.Tensor]
       default: null
     logprobs:
-      annotation: Optional[list[dict[int, tensorrt_llm.executor.result.Logprob]]]
+      annotation: Optional[list[dict[int, tensorrt_llm.executor.result.Logprob]] | List[float]]
       default: null
     prompt_logprobs:
      annotation: Optional[list[dict[int, tensorrt_llm.executor.result.Logprob]]]

tests/unittest/llmapi/apps/_test_openai_chat.py
Lines changed: 5 additions & 96 deletions

@@ -141,42 +141,14 @@ def test_single_chat_session(client: openai.OpenAI, model_name: str):
     message = chat_completion.choices[0].message
     assert message.content is not None
     assert message.role == "assistant"
-
-
-def test_single_chat_session_with_logprobs(client: openai.OpenAI,
-                                           model_name: str, backend: str):
-    if backend == "pytorch":
-        pytest.skip("Logprobs are not supported in PyTorch backend yet")
-
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
+    # test logprobs
     chat_completion = client.chat.completions.create(
         model=model_name,
         messages=messages,
         max_completion_tokens=10,
         logprobs=True,
     )
-    assert chat_completion.id is not None
-    assert len(chat_completion.choices) == 1
-    message = chat_completion.choices[0].message
-    assert message.content is not None
-    assert message.role == "assistant"
-    # test logprobs
     logprobs = chat_completion.choices[0].logprobs.content
-    finish_reason = chat_completion.choices[0].finish_reason
-    if finish_reason == "length":
-        assert len(logprobs) == 10
-    elif finish_reason == "stop":
-        assert len(logprobs) <= 10
-    else:
-        raise RuntimeError(
-            f"finish_reason {finish_reason} not in [length, stop]")
     for logprob in logprobs:
         assert logprob.token is not None
         assert logprob.logprob is not None
@@ -204,10 +176,11 @@ def test_multi_turn_dialogue(client: openai.OpenAI, model_name: str):
     assert message.content is not None and len(message.content) >= 0
 
 
-def test_multiple_response(client: openai.OpenAI, model_name: str,
-                           backend: str):
+def test_multiple_responses(client: openai.OpenAI, model_name: str,
+                            backend: str):
     if backend == "pytorch":
-        pytest.skip("Beam search is not supported in PyTorch backend yet")
+        pytest.skip(
+            "Multiple responses are not supported in PyTorch backend yet")
 
     messages = [{
         "role": "system",
@@ -252,70 +225,6 @@ async def test_chat_streaming(async_client: openai.AsyncOpenAI,
         "content": "what is 1+1?"
     }]
 
-    chat_completion = await async_client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_completion_tokens=10,
-        temperature=0.0,
-        logprobs=False,
-    )
-    output = chat_completion.choices[0].message.content
-    _finish_reason = chat_completion.choices[0].finish_reason
-
-    # test streaming
-    stream = await async_client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_completion_tokens=10,
-        temperature=0.0,
-        logprobs=False,
-        stream=True,
-    )
-    str_chunks: List[str] = []
-
-    finish_reason_counter = 0
-    finish_reason: str = None
-    async for chunk in stream:
-        choice = chunk.choices[0]
-        delta = choice.delta
-        if choice.finish_reason is not None:
-            finish_reason_counter += 1
-            finish_reason = choice.finish_reason
-        if delta.role:
-            assert delta.role == "assistant"
-        if delta.content:
-            str_chunks.append(delta.content)
-        # test finish_reason
-        if delta.content == "":
-            assert finish_reason == "stop"
-    assert finish_reason_counter == 1
-    assert finish_reason == _finish_reason
-    num_tokens = len(str_chunks)
-    if finish_reason == "length":
-        assert num_tokens == 10
-    elif finish_reason == "stop":
-        assert num_tokens <= 10
-    else:
-        raise RuntimeError(
-            f"finish_reason {finish_reason} not in [length, stop]")
-    # test generated tokens
-    assert "".join(str_chunks) == output
-
-
[email protected](loop_scope="module")
-async def test_chat_streaming_with_logprobs(async_client: openai.AsyncOpenAI,
-                                            model_name: str, backend: str):
-    if backend == "pytorch":
-        pytest.skip("Logprobs are not supported in PyTorch backend yet")
-
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
     chat_completion = await async_client.chat.completions.create(
         model=model_name,
         messages=messages,
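
With the pieces above in place, a client can request per-token top logprobs through the standard OpenAI chat API. A hedged usage sketch (the base URL, API key, and model name are placeholders for a locally running trtllm-serve instance; on the TRT backend the server must have gather_generation_logits enabled for top_logprobs to be honored):

```python
import openai

# placeholders: adjust to however the local trtllm-serve instance was launched
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

chat = client.chat.completions.create(
    model="my-served-model",  # placeholder model name
    messages=[
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ],
    max_completion_tokens=10,
    logprobs=True,    # must be true whenever top_logprobs is set
    top_logprobs=2,   # top-2 alternatives per generated token
)

for entry in chat.choices[0].logprobs.content:
    # entry.logprob is the sampled token's logprob; entry.top_logprobs holds
    # the alternatives filled in by create_logprobs on the server side
    print(entry.token, entry.logprob, [t.token for t in entry.top_logprobs])
```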
