Commit f68e03e

[https://nvbugs/5452167][fix] Fix ngram padding issue (NVIDIA#6837)
Signed-off-by: Mike Iovine <[email protected]>
Parent commit: 12102e2

File tree

2 files changed: 8 additions & 7 deletions
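In short: `get_draft_tokens` now takes a `padding_id` parameter in place of `end_id`, and both its fallback return value and the padding applied in `prepare_draft_tokens` use 0 rather than the request's `py_end_id`. A plausible motivation, inferred from the diff rather than stated in the commit message, is that padding draft tokens with the end-of-sequence id risks them being treated as real EOS tokens; the test is updated with `ignore_eos=True` accordingly.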

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 7 additions & 6 deletions
@@ -87,13 +87,13 @@ def get_draft_tokens(
         self,
         prefix: list[int],
         request_id: int,
-        end_id: int,
+        padding_id: int,
         max_sequence_length: int,
     ):
         prefix_len = len(prefix)
         max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len
         if max_draft_token_length_this_step <= 0:  # No draft token is needed if the prefix is long enough
-            return [end_id]
+            return [padding_id]
         if request_id not in self.start_index:  # Extend start_index and pool for a new request
             self.start_index[request_id] = 0
             if not self.is_public_pool:
@@ -126,7 +126,7 @@ def get_draft_tokens(
                 pool[pattern].add(new_match)
 
         # Find match
-        draft_tokens = [end_id]  # fallback value
+        draft_tokens = [padding_id]  # fallback value
         for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                           -1):
             pattern = tuple(prefix[-size:])
@@ -194,11 +194,12 @@ def prepare_draft_tokens(
             draft_tokens = self.spec_resource_manager.get_draft_tokens(
                 prefix,
                 request.request_id,
-                request.py_end_id,
-                request.py_orig_prompt_len + request.py_max_new_tokens,
+                padding_id=0,
+                max_sequence_length=request.py_orig_prompt_len +
+                request.py_max_new_tokens,
             )
             # Pad length to `self.max_draft_len`
             if len(draft_tokens) > 0:
                 pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([request.py_end_id] * pad_length)
+                draft_tokens.extend([0] * pad_length)
             request.py_draft_tokens = draft_tokens
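To make the new flow concrete, here is a minimal, self-contained sketch of the drafting-and-padding path after this change. The standalone helpers, the plain-dict ngram pool, and the toy values are illustrative assumptions, not the real TensorRT-LLM API; only the fallback-to-`padding_id` and pad-with-0 behavior mirrors the diff.

# Illustrative sketch only: the real implementation lives in
# tensorrt_llm/_torch/speculative/ngram.py and manages per-request pools.

PADDING_ID = 0  # the commit pads with 0 instead of the request's end_id


def get_draft_tokens(pool, prefix, padding_id, max_sequence_length,
                     max_matching_ngram_size=4):
    """Look up draft tokens for `prefix`; fall back to [padding_id]."""
    prefix_len = len(prefix)
    # No draft tokens are needed if the prefix already fills the budget.
    if max_sequence_length - 1 - prefix_len <= 0:
        return [padding_id]
    draft_tokens = [padding_id]  # fallback value, as in the diff
    # Try the longest ngram suffix of the prefix first, then shorter ones.
    for size in range(min(max_matching_ngram_size, prefix_len - 1), 0, -1):
        pattern = tuple(prefix[-size:])
        if pattern in pool:
            draft_tokens = list(pool[pattern])
            break
    return draft_tokens


def pad_draft_tokens(draft_tokens, max_draft_len):
    """Pad to a fixed length with PADDING_ID, never with end_id."""
    pad_length = max_draft_len - len(draft_tokens)
    return draft_tokens + [PADDING_ID] * pad_length


pool = {(3, 5): [7, 11]}  # toy pool: ngram suffix -> continuation tokens
drafts = get_draft_tokens(pool, prefix=[1, 3, 5], padding_id=PADDING_ID,
                          max_sequence_length=32)
print(pad_draft_tokens(drafts, max_draft_len=4))  # -> [7, 11, 0, 0]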

tests/unittest/_torch/speculative/test_ngram.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
         "The capital of France is",
         "The president of the United States is",
     ]
-    sampling_params = SamplingParams(max_tokens=32)
+    sampling_params = SamplingParams(max_tokens=32, ignore_eos=True)
 
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
     results_spec = llm_spec.generate(prompts, sampling_params)
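A plausible reading of this one-line test change (inferred from the diff, not stated in the commit): `ignore_eos=True` makes both the speculative and reference runs generate the full 32 tokens, so an early end-of-sequence token cannot cut one run short and skew the comparison between them.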
