Skip to content

Commit 43a2695

Browse files
feat: add logits processor support for trtllm backend (#2702)
Signed-off-by: Bhuvan Agrawal <[email protected]> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 699996e commit 43a2695

File tree

10 files changed

+322
-30
lines changed

10 files changed

+322
-30
lines changed

components/backends/trtllm/README.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
4343
- [Client](#client)
4444
- [Benchmarking](#benchmarking)
4545
- [Multimodal Support](#multimodal-support)
46+
- [Logits Processing](#logits-processing)
4647
- [Performance Sweep](#performance-sweep)
4748

4849
## Feature Support Matrix
@@ -242,6 +243,63 @@ To benchmark your deployment with GenAI-Perf, see this utility script, configuri
242243

243244
Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [Multimodal Support Guide](./multimodal_support.md).
244245

246+
## Logits Processing
247+
248+
Logits processors let you modify the next-token logits at every decoding step (e.g., to apply custom constraints or sampling transforms). Dynamo provides a backend-agnostic interface and an adapter for TensorRT-LLM so you can plug in custom processors.
249+
250+
### How it works
251+
- **Interface**: Implement `dynamo.logits_processing.BaseLogitsProcessor` which defines `__call__(input_ids, logits)` and modifies `logits` in-place.
252+
- **TRT-LLM adapter**: Use `dynamo.trtllm.logits_processing.adapter.create_trtllm_adapters(...)` to convert Dynamo processors into TRT-LLM-compatible processors and assign them to `SamplingParams.logits_processor`.
253+
- **Examples**: See example processors in `lib/bindings/python/src/dynamo/logits_processing/examples/` ([temperature](../../../lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py), [hello_world](../../../lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py)).
254+
255+
### Quick test: HelloWorld processor
256+
You can enable a test-only processor that forces the model to respond with "Hello world!". This is useful to verify the wiring without modifying your model or engine code.
257+
258+
```bash
259+
cd $DYNAMO_HOME/components/backends/trtllm
260+
export DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1
261+
./launch/agg.sh
262+
```
263+
264+
Notes:
265+
- When enabled, Dynamo initializes the tokenizer so the HelloWorld processor can map text to token IDs.
266+
- Expected chat response contains "Hello world".
267+
268+
### Bring your own processor
269+
Implement a processor by conforming to `BaseLogitsProcessor` and modify logits in-place. For example, temperature scaling:
270+
271+
```python
272+
from typing import Sequence
273+
import torch
274+
from dynamo.logits_processing import BaseLogitsProcessor
275+
276+
class TemperatureProcessor(BaseLogitsProcessor):
277+
def __init__(self, temperature: float = 1.0):
278+
if temperature <= 0:
279+
raise ValueError("Temperature must be positive")
280+
self.temperature = temperature
281+
282+
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
283+
if self.temperature == 1.0:
284+
return
285+
logits.div_(self.temperature)
286+
```
287+
288+
Wire it into TRT-LLM by adapting and attaching to `SamplingParams`:
289+
290+
```python
291+
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
292+
from dynamo.logits_processing.examples import TemperatureProcessor
293+
294+
processors = [TemperatureProcessor(temperature=0.7)]
295+
sampling_params.logits_processor = create_trtllm_adapters(processors)
296+
```
297+
298+
### Current limitations
299+
- Per-request processing only (batch size must be 1); beam width > 1 is not supported.
300+
- Processors must modify logits in-place and not return a new tensor.
301+
- If your processor needs tokenization, ensure the tokenizer is initialized (do not skip tokenizer init).
302+
245303
## Performance Sweep
246304

247305
For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](./performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import logging
5+
from typing import List, Optional
6+
7+
import torch
8+
from tensorrt_llm.sampling_params import LogitsProcessor
9+
10+
from dynamo.logits_processing import BaseLogitsProcessor
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class TrtllmDynamoLogitsAdapter(LogitsProcessor):
16+
"""
17+
Adapter that wraps Dynamo BaseLogitsProcessor instances to work with TensorRT-LLM's logits processor interface.
18+
19+
Inherits from tensorrt_llm.LogitsProcessor and implements the required interface:
20+
__call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr, client_id: Optional[int])
21+
22+
This adapter maintains per-request state and converts between the interfaces.
23+
"""
24+
25+
def __init__(self, processor: BaseLogitsProcessor):
26+
super().__init__()
27+
self.processor = processor
28+
29+
def __call__(
30+
self,
31+
req_ids: int,
32+
logits: torch.Tensor,
33+
ids: List[List[int]],
34+
stream_ptr,
35+
client_id: Optional[int] = None,
36+
):
37+
"""
38+
TensorRT-LLM logits processor interface.
39+
40+
Args:
41+
req_ids: Request identifier
42+
logits: Logits tensor for current step
43+
ids: List of token sequences (batch of sequences)
44+
stream_ptr: CUDA stream pointer
45+
client_id: Optional client identifier
46+
47+
Returns:
48+
Modified logits tensor (in-place modification expected by TRT-LLM)
49+
"""
50+
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
51+
try:
52+
with torch.cuda.stream(stream):
53+
if logits.shape[0] != 1:
54+
raise ValueError(
55+
f"This logits adapter only supports per-request logits processing. "
56+
f"Received logits with batch size {logits.shape[0]} expected 1"
57+
)
58+
if logits.shape[1] != 1:
59+
raise ValueError(
60+
"Logits processing with beam width > 1 is not supported"
61+
)
62+
# Call the processor which modifies the logits in-place
63+
self.processor(ids[0], logits[0, 0, :])
64+
65+
except Exception as e:
66+
logger.error(f"Error in logits processor for request {req_ids}: {e}")
67+
# Don't modify logits on error
68+
69+
# TRT-LLM expects void return (in-place modification)
70+
71+
72+
def create_trtllm_adapters(
73+
processors: List[BaseLogitsProcessor],
74+
) -> List[TrtllmDynamoLogitsAdapter]:
75+
"""
76+
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
77+
78+
Args:
79+
processors: List of Dynamo BaseLogitsProcessor instances
80+
81+
Returns:
82+
List of TensorRT-LLM compatible logits processor adapters
83+
"""
84+
adapters = []
85+
for processor in processors:
86+
adapter = TrtllmDynamoLogitsAdapter(processor)
87+
adapters.append(adapter)
88+
return adapters

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import asyncio
55
import logging
6+
import os
67
import signal
78
import sys
89

@@ -225,6 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config):
225226
modelType = ModelType.Backend
226227
multimodal_processor = None
227228

229+
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
230+
# We need to initialize the tokenizer for the test logits processor
231+
# But detokenizing still happens in the rust engine, so we do _not_ want
232+
# to set default_sampling_params.detokenize to True.
233+
engine_args["skip_tokenizer_init"] = False
234+
228235
if modality == "multimodal":
229236
engine_args["skip_tokenizer_init"] = False
230237
modelType = ModelType.Chat

components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import copy
1717
import logging
18+
import os
1819
from dataclasses import asdict, dataclass
1920
from enum import Enum
2021
from typing import Optional, Union
@@ -23,9 +24,11 @@
2324
from tensorrt_llm import SamplingParams
2425
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
2526

27+
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
2628
from dynamo.nixl_connect import Connector
2729
from dynamo.runtime.logging import configure_dynamo_logging
2830
from dynamo.trtllm.engine import TensorRTLLMEngine
31+
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
2932
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
3033
from dynamo.trtllm.publisher import Publisher
3134
from dynamo.trtllm.utils.disagg_utils import (
@@ -182,6 +185,12 @@ async def generate_locally(
182185
request_id = request.get("id") or request.get("request_id", "unknown-id")
183186
model_name = request.get("model", "unknown_model")
184187

188+
# Optional test-only logits processing (enable with DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1)
189+
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
190+
processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
191+
adapters = create_trtllm_adapters(processors)
192+
sampling_params.logits_processor = adapters
193+
185194
# NEW: Updated engine call to include multimodal data
186195
async for res in self.engine.llm.generate_async(
187196
inputs=processed_input, # Use the correctly extracted inputs

lib/bindings/python/src/dynamo/logits_processing/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
logits processors must implement.
99
"""
1010

11-
from typing import Protocol, Sequence
11+
from typing import Protocol, Sequence, runtime_checkable
1212

1313
import torch
1414

1515

16+
@runtime_checkable
1617
class BaseLogitsProcessor(Protocol):
1718
"""
1819
Protocol for logits processors in Dynamo.
@@ -25,15 +26,14 @@ def __call__(
2526
self,
2627
input_ids: Sequence[int],
2728
logits: torch.Tensor,
28-
) -> torch.Tensor:
29+
) -> None:
2930
"""
3031
Process the logits for the next token prediction.
3132
3233
Args:
3334
input_ids: The input token IDs generated so far.
3435
logits: The raw logits for the next token. Shape: (vocab_size,)
3536
36-
Returns:
37-
A tensor with the same shape, dtype, and device as `logits`.
37+
The processor is expected to modify the logits in-place.
3838
"""
3939
...
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from .hello_world import HelloWorldLogitsProcessor
5+
from .temperature import TemperatureProcessor
6+
7+
__all__ = ["TemperatureProcessor", "HelloWorldLogitsProcessor"]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Sequence
5+
6+
import torch
7+
from transformers import PreTrainedTokenizerBase
8+
9+
from dynamo.logits_processing import BaseLogitsProcessor
10+
11+
RESPONSE = "Hello world!"
12+
13+
14+
class HelloWorldLogitsProcessor(BaseLogitsProcessor):
15+
"""
16+
Sample Logits Processor that always outputs a hardcoded
17+
response (`RESPONSE`), no matter the input
18+
"""
19+
20+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
21+
self.tokenizer = tokenizer
22+
self.token_ids = tokenizer.encode(RESPONSE, add_special_tokens=False)
23+
self.eos_id = tokenizer.eos_token_id
24+
if self.eos_id is None:
25+
raise ValueError(
26+
"Tokenizer has no eos_token_id; HelloWorldLogitsProcessor requires one."
27+
)
28+
self.state = 0
29+
30+
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor):
31+
mask = torch.full_like(scores, float("-inf"))
32+
33+
if self.state < len(self.token_ids):
34+
token_idx = self.token_ids[self.state]
35+
else:
36+
token_idx = self.eos_id
37+
# Allow only a single token to be output
38+
mask[token_idx] = 0.0
39+
40+
# The `scores` tensor *must* also be modified in-place
41+
scores.add_(mask)
42+
self.state += 1
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Sequence
5+
6+
import torch
7+
8+
from dynamo.logits_processing import BaseLogitsProcessor
9+
10+
11+
class TemperatureProcessor(BaseLogitsProcessor):
12+
"""
13+
Example logits processor that applies temperature scaling.
14+
15+
This is a simple demonstration of how to implement a logits processor
16+
that can be used with any Dynamo backend.
17+
"""
18+
19+
def __init__(self, temperature: float = 1.0):
20+
"""
21+
Args:
22+
temperature: Scaling factor. Higher values make distribution more uniform,
23+
lower values make it more peaked. Must be positive.
24+
"""
25+
if temperature <= 0:
26+
raise ValueError("Temperature must be positive")
27+
self.temperature = temperature
28+
29+
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
30+
"""
31+
Apply temperature scaling to logits.
32+
33+
Args:
34+
input_ids: Token IDs generated so far (unused in this simple example)
35+
logits: Raw logits tensor from model
36+
37+
The processor is expected to modify the logits in-place.
38+
"""
39+
if self.temperature == 1.0:
40+
return
41+
logits.div_(self.temperature)

tests/serve/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""
55

6+
import os
67
from dataclasses import dataclass
78
from typing import Any, Callable, List
89

@@ -32,6 +33,11 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
3233
3334
This provides the default implementation for text-only models.
3435
"""
36+
expected_response = (
37+
["Hello world"]
38+
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1"
39+
else ["AI"]
40+
)
3541
return Payload(
3642
payload_chat={
3743
"model": config.model,
@@ -54,5 +60,5 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
5460
},
5561
repeat_count=3,
5662
expected_log=[],
57-
expected_response=["AI"],
63+
expected_response=expected_response,
5864
)

0 commit comments

Comments
 (0)