
Commit 61da2da

[TRTLLM-6761][refactor] Replace LogitBiasLogitsProcessor with embedding bias tensor system (#6464)
Signed-off-by: Venky Ganesh <[email protected]>
1 parent 6a9b4b1 commit 61da2da

11 files changed: +423 −125 lines


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 9 additions & 3 deletions
@@ -336,6 +336,14 @@ def __init__(
             exclude_last_generation_logits)
         self.child_requests = []
 
+        self._py_embedding_bias_1d = None
+        if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
+            # Pre-squeeze to 1D if needed (remove batch dimension)
+            if self.embedding_bias.dim() > 1:
+                self._py_embedding_bias_1d = self.embedding_bias.squeeze(0)
+            else:
+                self._py_embedding_bias_1d = self.embedding_bias
+
     def is_generation_only_request(self):
         return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
 
@@ -463,9 +471,7 @@ def executor_request_to_llm_request(
         is_streaming=executor_request.streaming,
         end_id=executor_request.end_id,
         pad_id=executor_request.pad_id,
-        embedding_bias=torch.tensor(executor_request.embedding_bias,
-                                    dtype=torch.int32)
-        if executor_request.embedding_bias else None,
+        embedding_bias=executor_request.embedding_bias,
         bad_words_list=torch.tensor(
            convert_wordlist(executor_request.bad_words), dtype=torch.int32)
        if executor_request.bad_words else None,
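
Note: below is a minimal standalone sketch (not part of the commit) of the pre-squeeze behaviour added to LlmRequest.__init__ above. An embedding_bias that arrives with a leading batch dimension of 1 is cached once as a 1-D [vocab_size] tensor, so the sampler can stack per-request biases later without reshaping on every decoding step.

import torch

vocab_size = 8                                # toy size for illustration
embedding_bias = torch.zeros(1, vocab_size)   # [1, vocab_size], as built from a logit_bias dict
embedding_bias[0, 3] = -100.0                 # strongly suppress token id 3

# Mirror of the cached attribute: squeeze away the batch dimension once.
bias_1d = embedding_bias.squeeze(0) if embedding_bias.dim() > 1 else embedding_bias
assert bias_1d.shape == (vocab_size,)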

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 66 additions & 1 deletion
@@ -391,6 +391,58 @@ def append_eagle3(tokens: torch.Tensor, model_outputs):
         d2t = model_outputs["d2t"][tokens]
         tokens += d2t
 
+    @staticmethod
+    def _apply_embedding_bias(
+            logits: torch.Tensor,
+            requests: list[LlmRequest],
+            steps_per_request: list[int] = None) -> torch.Tensor:
+        """Apply embedding bias (aka logit bias) to logits.
+        If steps_per_request is None, assumes 1 step per request (non-batched path).
+        """
+        # Collect biases and their associated data
+        bias_list = []
+        bias_data = []  # Either indices (fast path) or steps (batched path)
+
+        for i, req in enumerate(requests):
+            bias = req._py_embedding_bias_1d
+            if bias is not None:
+                bias_list.append(bias)
+                bias_data.append(i if steps_per_request is
+                                 None else steps_per_request[i])
+
+        if not bias_list:
+            return logits
+
+        bias_tensor = torch.stack(bias_list).to(logits.device,
+                                                non_blocking=True)
+        logits = logits.clone()
+
+        if steps_per_request is None:
+            # Fast path: direct indexing
+            indices = torch.tensor(bias_data, device=logits.device)
+            logits[indices] += bias_tensor
+        else:
+            # Batched path: expand biases and use boolean mask
+            expanded_biases = torch.repeat_interleave(bias_tensor,
+                                                      torch.tensor(
+                                                          bias_data,
+                                                          device=logits.device),
+                                                      dim=0)
+
+            mask = torch.zeros(sum(steps_per_request),
+                               dtype=torch.bool,
+                               device=logits.device)
+            offset = 0
+            for i, req in enumerate(requests):
+                steps = steps_per_request[i]
+                if req._py_embedding_bias_1d is not None:
+                    mask[offset:offset + steps] = True
+                offset += steps
+
+            logits[mask] += expanded_biases
+
+        return logits
+
     def _process_requests(self,
                           requests: list[LlmRequest],
                           model_outputs: dict[str, torch.Tensor],
 
@@ -411,6 +463,7 @@ def _process_requests(self,
 
         if fast_path:
             logits = raw_logits[:len(requests)]
+            logits = self._apply_embedding_bias(logits, requests)
             next_tokens = torch.argmax(logits, dim=-1)
             self.append_eagle3(next_tokens, model_outputs)
             int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
 
@@ -430,17 +483,29 @@ def _process_requests(self,
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
+            # Collect steps per request for batched strategy
+            steps_per_request = [
+                1 + len(req.py_draft_tokens) for req in requests
+            ]
+            logits = self._apply_embedding_bias(logits, requests,
+                                                steps_per_request)
             batched_next_tokens, batched_softmax = sample(
                 batched_strategy, logits)
             self.append_eagle3(batched_next_tokens, model_outputs)
 
         offset = 0
-        for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
+        for i, (strategy, slot,
+                steps) in enumerate(zip(strategies, seq_slots, num_steps)):
             input_slice = slice(offset, offset + steps)
             logits = raw_logits[input_slice]
+
+            req = requests[i]
+
             if batched_next_tokens is None:
+                logits = self._apply_embedding_bias(logits, [req])
                 next_tokens, softmax = sample(strategy, logits)
             else:
+                # Batched processing already applied bias, just use the results
                 next_tokens = batched_next_tokens[input_slice]
                 softmax = batched_softmax[input_slice]
             current_slice = slice(0, steps), slot, beam
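
Note: to make the two paths above concrete, here is a self-contained sketch (not the repo's code; a plain list of tensors stands in for LlmRequest and its _py_embedding_bias_1d attribute). The fast path adds each request's bias to its single logits row by index; the batched path repeats a request's bias across its 1 + draft-token rows and selects those rows with a boolean mask.

import torch

def apply_bias(logits, biases, steps_per_request=None):
    # `biases[i]` is the cached 1-D bias of request i, or None.
    bias_list, bias_data = [], []
    for i, bias in enumerate(biases):
        if bias is not None:
            bias_list.append(bias)
            bias_data.append(i if steps_per_request is None else steps_per_request[i])
    if not bias_list:
        return logits
    bias_tensor = torch.stack(bias_list).to(logits.device)
    logits = logits.clone()
    if steps_per_request is None:
        # Fast path: one logits row per request, index the biased rows directly.
        logits[torch.tensor(bias_data, device=logits.device)] += bias_tensor
    else:
        # Batched path: each request owns `steps` consecutive rows; repeat its
        # bias over those rows and add it through a boolean row mask.
        expanded = torch.repeat_interleave(
            bias_tensor, torch.tensor(bias_data, device=logits.device), dim=0)
        mask = torch.zeros(sum(steps_per_request), dtype=torch.bool, device=logits.device)
        offset = 0
        for steps, bias in zip(steps_per_request, biases):
            if bias is not None:
                mask[offset:offset + steps] = True
            offset += steps
        logits[mask] += expanded
    return logits

vocab = 4
biases = [torch.tensor([0., -100., 0., 0.]), None]          # only request 0 carries a bias
fast = apply_bias(torch.zeros(2, vocab), biases)             # 1 step per request
batched = apply_bias(torch.zeros(5, vocab), biases, [2, 3])  # request 0 has 2 rows, request 1 has 3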

tensorrt_llm/sampling_params.py

Lines changed: 7 additions & 50 deletions
@@ -2,7 +2,7 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from pydantic import BaseModel
 
@@ -108,55 +108,6 @@ def __call__(
         pass  # noqa
 
 
-class LogitBiasLogitsProcessor(LogitsProcessor):
-    def __init__(self, logit_bias: Dict[str, float]) -> None:
-        super().__init__()
-        self.logit_bias = logit_bias
-        self.tokens_to_adjust = self.process_logit_bias(logit_bias)
-        if not self.tokens_to_adjust:
-            raise ValueError("Empty logit_bias provided - no tokens to adjust")
-
-    def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
-        valid = {}
-        invalid = {}
-
-        for k, v in logit_bias.items():
-            try:
-                token_id = int(k)
-                valid[token_id] = v
-            except (ValueError, TypeError):
-                invalid[k] = v
-
-        if invalid:
-            raise ValueError(
-                f"Invalid token_ids in logit_bias: {list(invalid.keys())}. "
-                f"All keys must be integers."
-            )
-        return valid
-
-    def __call__(
-        self,
-        req_id: int,
-        logits: torch.Tensor,
-        token_ids: List[List[int]],
-        stream_ptr: Optional[int],
-        client_id: Optional[int],
-    ) -> None:
-        vocab_size = logits.size(-1)
-        token_ids_list = list(self.tokens_to_adjust.keys())
-        bias_values = torch.tensor(list(self.tokens_to_adjust.values()), device=logits.device)
-
-        invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
-        if invalid_token_ids:
-            raise ValueError(
-                f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
-            )
-
-        stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
-        with torch.cuda.stream(stream):
-            logits[:, :, token_ids_list] += bias_values
-
-
 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
     """An additional output to gather from the model.
 
@@ -328,6 +279,12 @@ def __post_init__(self):
 
         self.best_of = self.best_of or self.n
 
+        if self.embedding_bias is not None:
+            if isinstance(self.embedding_bias, torch.Tensor):
+                self.embedding_bias = self.embedding_bias.detach().clone()
+            else:
+                self.embedding_bias = torch.tensor(self.embedding_bias, dtype=torch.float32)
+
         self._validate()
 
     def _validate(self):
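
Note: a small illustrative sketch (not the repo's code) of what the new __post_init__ branch above does: list input becomes a float32 tensor, and tensor input is detached and copied so the caller's tensor is not aliased by the request.

import torch

def normalize_embedding_bias(embedding_bias):
    # Mirrors the __post_init__ branch added above.
    if embedding_bias is None:
        return None
    if isinstance(embedding_bias, torch.Tensor):
        return embedding_bias.detach().clone()
    return torch.tensor(embedding_bias, dtype=torch.float32)

as_list = normalize_embedding_bias([0.0, -5.0, 2.5])
assert as_list.dtype == torch.float32

original = torch.ones(4, requires_grad=True)
copied = normalize_embedding_bias(original)
assert copied.data_ptr() != original.data_ptr()   # independent storage, no grad tracking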

tensorrt_llm/serve/openai_protocol.py

Lines changed: 35 additions & 7 deletions
@@ -5,6 +5,7 @@
 import uuid
 from typing import Any, Dict, List, Literal, Optional, Union
 
+import torch
 from openai.types.chat import \
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
 from openai.types.chat import \
 
@@ -16,7 +17,34 @@
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
 
-from ..sampling_params import LogitBiasLogitsProcessor
+
+def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
+                                  vocab_size: int) -> Optional[torch.Tensor]:
+    """Convert OpenAI logit_bias dict to embedding_bias tensor for sampling."""
+    if logit_bias is None:
+        return None
+
+    # Create 1D zeros tensor as expected by executor API (will be unsqueezed to [1, vocab_size] internally)
+    embedding_bias = torch.zeros(vocab_size, dtype=torch.float32)
+
+    # Apply biases for specified token IDs
+    for token_str, bias in logit_bias.items():
+        try:
+            token_id = int(token_str)
+            if 0 <= token_id < vocab_size:
+                embedding_bias[token_id] = bias
+            else:
+                raise ValueError(
+                    f"Token ID {token_id} out of vocabulary range [0, {vocab_size})"
+                )
+        except ValueError as e:
+            if "invalid literal" in str(e):
+                raise ValueError(
+                    f"Invalid logit_bias key '{token_str}': must be a valid integer token ID"
+                )
+            raise
+
+    return embedding_bias
 
 
 class OpenAIBaseModel(BaseModel):
 
@@ -225,7 +253,7 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
         sampling_params = SamplingParams(
             best_of=self.best_of,
             frequency_penalty=self.frequency_penalty,
 
@@ -258,8 +286,8 @@ def to_sampling_params(self) -> SamplingParams:
             detokenize=self.detokenize,
 
             # logits_bias
-            logits_processor=None if not self.logit_bias else
-            LogitBiasLogitsProcessor(self.logit_bias),
+            embedding_bias=_logit_bias_to_embedding_bias(
+                self.logit_bias, vocab_size),
 
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
 
@@ -521,7 +549,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
 
         sampling_params = SamplingParams(
             frequency_penalty=self.frequency_penalty,
 
@@ -553,8 +581,8 @@ def to_sampling_params(self) -> SamplingParams:
                 self.response_format),
 
             # logits_bias
-            logits_processor=None if not self.logit_bias else
-            LogitBiasLogitsProcessor(self.logit_bias),
+            embedding_bias=_logit_bias_to_embedding_bias(
+                self.logit_bias, vocab_size),
 
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
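
Note: a quick usage sketch of the helper defined above. The import path follows this commit's placement of the module-private helper; the token ids and vocab size are made up for the example.

import torch

from tensorrt_llm.serve.openai_protocol import _logit_bias_to_embedding_bias

bias = _logit_bias_to_embedding_bias({"7": 5.0, "42": -100.0}, vocab_size=128)
assert bias.shape == (128,) and bias.dtype == torch.float32
assert bias[7] == 5.0 and bias[42] == -100.0

# Both of these raise ValueError: a non-integer key, and an out-of-range token id.
#   _logit_bias_to_embedding_bias({"not_an_id": 1.0}, vocab_size=128)
#   _logit_bias_to_embedding_bias({"999": 1.0}, vocab_size=128)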

tensorrt_llm/serve/openai_server.py

Lines changed: 8 additions & 3 deletions
@@ -253,11 +253,13 @@ async def create_chat_response(
         tool_dicts = None if request.tools is None else [
             tool.model_dump() for tool in request.tools
         ]
-        sampling_params = request.to_sampling_params()
+        # Pass the tokenizer vocabulary size so ``logit_bias`` can be
+        # expanded into an embedding bias tensor in the sampler.
+        sampling_params = request.to_sampling_params(
+            vocab_size=self.tokenizer.tokenizer.vocab_size)
         # TODO: better way to enable metrics
         if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
             sampling_params.return_perf_metrics = True
-
         postproc_args = ChatPostprocArgs.from_request(request)
         disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
 
@@ -406,7 +408,10 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
 
         promises: List[RequestOutput] = []
         postproc_params_collection: List[Optional[PostprocParams]] = []
-        sampling_params = request.to_sampling_params()
+        # Pass the tokenizer vocabulary size so ``logit_bias`` can be
+        # expanded into an embedding bias tensor in the sampler.
+        sampling_params = request.to_sampling_params(
+            vocab_size=self.tokenizer.tokenizer.vocab_size)
         # TODO: better way to enable metrics
         if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
             sampling_params.return_perf_metrics = True
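
Note: with the vocab size plumbed through, an OpenAI-compatible client can exercise the new path end to end. A minimal client-side sketch follows; the base URL, model name, and token id are placeholders, not values from this commit.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

resp = client.chat.completions.create(
    model="my-model",                 # placeholder model name
    messages=[{"role": "user", "content": "Name a color."}],
    logit_bias={"1234": -100},        # strongly suppress (hypothetical) token id 1234
    max_tokens=16,
)
print(resp.choices[0].message.content)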

tests/integration/defs/test_e2e.py

Lines changed: 22 additions & 2 deletions
@@ -1399,18 +1399,20 @@ def test_openai_misc_example(llm_root, llm_venv, backend: str):
 @pytest.mark.parametrize("backend", ["pytorch", "trt"])
 def test_openai_completions_example(llm_root, llm_venv, backend: str):
     test_root = unittest_path() / "llmapi" / "apps"
+    filter_expr = f"{backend} and not sampler"
     llm_venv.run_cmd([
         "-m", "pytest",
-        str(test_root / "_test_openai_completions.py"), "-k", backend
+        str(test_root / "_test_openai_completions.py"), "-k", filter_expr
     ])
 
 
 @pytest.mark.parametrize("backend", ["pytorch", "trt"])
 def test_openai_chat_example(llm_root, llm_venv, backend: str):
     test_root = unittest_path() / "llmapi" / "apps"
+    filter_expr = f"{backend} and not sampler"
     llm_venv.run_cmd([
         "-m", "pytest",
-        str(test_root / "_test_openai_chat.py"), "-k", backend
+        str(test_root / "_test_openai_chat.py"), "-k", filter_expr
     ])
 
 
@@ -1423,6 +1425,24 @@ def test_openai_reasoning(llm_root, llm_venv, backend: str):
     ])
 
 
+@pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"])
+def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_completions.py"), "-k", sampler
+    ])
+
+
+@pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"])
+def test_openai_chat_with_logit_bias(llm_root, llm_venv, sampler: str):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_chat.py"), "-k", sampler
+    ])
+
+
 def test_openai_lora(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])

tests/integration/test_lists/test-db/l0_l40s.yml

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ l0_l40s:
       - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
       - test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
       - test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
+      - test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
+      - test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
+      - test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]
+      - test_e2e.py::test_openai_chat_with_logit_bias[trtllm_sampler]
 - condition:
     ranges:
       system_gpu_count:
