Commit b8f036f

[TRTLLM-6650][fix] Enhance CUDA graph + Beam search to correctly handle padding (#6665)
Signed-off-by: Stefan Niebler <[email protected]>
1 parent e251f7c commit b8f036f

File tree

3 files changed: 26 additions & 15 deletions


tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 10 additions & 7 deletions
@@ -34,6 +34,7 @@ def __init__(
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -49,19 +50,21 @@ def __init__(
         e.g. FlashInfer cause graph breaks).
         """
         self.batch_size = batch_size
-
+        self.max_beam_width = max_beam_width
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
         token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1

         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
-        self.input_ids = torch.ones((batch_size * token_per_request, ),
-                                    device=device,
-                                    dtype=torch.int32)
-        self.position_ids = torch.zeros((1, batch_size * token_per_request),
-                                        device=device,
-                                        dtype=torch.int32)
+        self.input_ids = torch.ones(
+            (batch_size * max_beam_width * token_per_request, ),
+            device=device,
+            dtype=torch.int32)
+        self.position_ids = torch.zeros(
+            (1, batch_size * max_beam_width * token_per_request),
+            device=device,
+            dtype=torch.int32)
         self.mrope_position_deltas = torch.zeros(
             (batch_size,
              1), device=device, dtype=torch.int32) if use_mrope else None
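
For readers following the change: with beam search, every request in the captured batch contributes max_beam_width sequences, so the static graph input buffers must hold batch_size * max_beam_width * token_per_request slots. Below is a minimal, hypothetical sketch of that sizing; the function name and signature are illustrative and not the DecodingCUDAGraphRunner API.

# Hypothetical sketch of the buffer sizing introduced above; only the tensor
# shapes mirror the diff, all other names are illustrative.
from typing import Optional, Tuple

import torch


def make_graph_input_buffers(
        batch_size: int,
        max_beam_width: int = 1,
        max_draft_len: Optional[int] = None,
        device: str = "cpu") -> Tuple[torch.Tensor, torch.Tensor]:
    # One target token plus any speculative draft tokens per request.
    token_per_request = max_draft_len + 1 if max_draft_len is not None else 1
    num_slots = batch_size * max_beam_width * token_per_request
    # Ones (not zeros) for input IDs to avoid NaNs in some models, as noted above.
    input_ids = torch.ones((num_slots, ), device=device, dtype=torch.int32)
    position_ids = torch.zeros((1, num_slots), device=device, dtype=torch.int32)
    return input_ids, position_ids


# Example: 2 requests, beam width 4, no speculative decoding -> 8 slots.
input_ids, position_ids = make_graph_input_buffers(batch_size=2, max_beam_width=4)
assert input_ids.shape == (8, ) and position_ids.shape == (1, 8)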

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 9 additions & 6 deletions
@@ -846,8 +846,8 @@ def _get_padded_batch(
             spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        # The number of sequences in the batch is the number of prompts times the beam width.
-        new_batch_size = batch_size * self.max_beam_width
+        new_batch_size = batch_size
+
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
@@ -981,8 +981,8 @@ def _maybe_get_cuda_graph(
             self._cuda_graphs[batch_size] = {}

         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
-            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
+            self.max_beam_width)
         return self._cuda_graphs[batch_size][draft_len]

     def __del__(self) -> None:
@@ -1376,8 +1376,11 @@ def _prepare_tp_inputs(
                 gather_ids.append(len(position_ids) - 1)

                 request_ids.append(request.py_request_id)
-                gen_request_seq_slots.append(request.py_seq_slot)
                 request.py_batch_idx = request.py_seq_slot
+                # Do not add a gen_request_seq_slot for CUDA graph dummy requests
+                # to prevent access errors due to None values
+                if not request.is_cuda_graph_dummy:
+                    gen_request_seq_slots.append(request.py_seq_slot)

         previous_batch_len = len(previous_batch_indices)

@@ -1506,7 +1509,7 @@ def previous_seq_slots_device():
                 pin_memory=True,
             )

-        num_generation_requests = len(scheduled_requests.generation_requests)
+        num_generation_requests = len(gen_request_seq_slots)
         # Cache indirection is only used for beam search on generation requests
         if self.use_beam_search and num_generation_requests > 0:
             # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
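
The guard added in _prepare_tp_inputs matters because padded batches contain CUDA graph dummy requests whose sequence slot is None; counting them, or reading their slots, would break the cache-indirection setup for beam search. A small sketch of the skipping logic follows, using a hypothetical Request stand-in for the executor's request object.

# Hypothetical stand-in types; only the skipping logic mirrors the diff above.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    py_seq_slot: Optional[int]
    is_cuda_graph_dummy: bool = False


def collect_gen_request_seq_slots(generation_requests: List[Request]) -> List[int]:
    gen_request_seq_slots = []
    for request in generation_requests:
        # Padding-only dummy requests have no usable seq slot; skip them so no
        # None value ends up in the list used for cache indirection.
        if not request.is_cuda_graph_dummy:
            gen_request_seq_slots.append(request.py_seq_slot)
    return gen_request_seq_slots


# Two real requests padded with one CUDA graph dummy: only the real slots are
# kept, so num_generation_requests == len(gen_request_seq_slots) == 2.
slots = collect_gen_request_seq_slots([
    Request(py_seq_slot=0),
    Request(py_seq_slot=1),
    Request(py_seq_slot=None, is_cuda_graph_dummy=True),
])
assert slots == [0, 1]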

tests/unittest/_torch/test_beam_search.py

Lines changed: 7 additions & 2 deletions
@@ -61,7 +61,8 @@ def llm_cuda_graph(fixed_params, input_prompts):
         max_seq_len=32,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        cuda_graph_config=CudaGraphConfig(),
+        cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
+                                          enable_padding=True),
     )


@@ -126,7 +127,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("gather_generation_logits", [True, False])
 @pytest.mark.parametrize("gather_context_logits", [True, False])
 @pytest.mark.parametrize("num_output_beams", [1, 2])
-@pytest.mark.parametrize("num_prompts", [1, 2])
+@pytest.mark.parametrize("num_prompts", [1, 2, 3])
 @pytest.mark.threadleak(enabled=False)
 def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
@@ -145,6 +146,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
+    # test padding of cuda graph with 3 prompts
+    # replicate the prompts to have more than 2 prompts available
+    if (num_prompts == 3 and len(input_prompts) == 2):
+        input_prompts = [input_prompts[0]] * 3
     outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
                                       sampling_params=sampling_params)
     assert len(outputs) == num_prompts
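
The new num_prompts == 3 case picks a batch size that is not among the captured sizes [1, 2, 4, 8], so with enable_padding=True the batch is expected to be padded up to the next captured size. A rough sketch of that rounding, under the assumption that padding rounds a batch up to the nearest captured CUDA graph batch size:

# Illustrative sketch only; assumes enable_padding rounds a batch up to the
# nearest captured CUDA graph batch size, which the 3-prompt case exercises.
def padded_graph_batch_size(batch_size, captured_sizes=(1, 2, 4, 8)):
    for size in sorted(captured_sizes):
        if batch_size <= size:
            return size
    return batch_size  # larger than any captured graph; fall back to eager mode


assert padded_graph_batch_size(3) == 4  # the 3-prompt case pads to a batch of 4
assert padded_graph_batch_size(2) == 2  # an exact match needs no padding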
