Skip to content

Commit a9561b3

Browse files
mikeiovine and Ria Jain
authored and committed
[None][feat] Optimize CUDA graph memory usage for spec decode cases (NVIDIA#6718)
Signed-off-by: Mike Iovine <[email protected]>
1 parent d45236b commit a9561b3

File tree

12 files changed

+540
-21
lines changed

12 files changed

+540
-21
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -726,8 +726,11 @@ def disable_optimization(backend: Backend):
726726
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
727727
# so that when we disable spec decode at runtime, we can still run the captured graph.
728728
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
729-
if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine(
730-
):
729+
if (not self.is_draft_model and self.max_draft_len > 0
730+
and not self.spec_config.spec_dec_mode.use_one_engine()
731+
# Assume that speculation is always on if the user didn't give us a max_concurrency
732+
# value. This will save on memory.
733+
and self.spec_config.max_concurrency is not None):
731734
draft_lengths.append(0)
732735

733736
for bs in cuda_graph_batch_sizes:

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -937,11 +937,30 @@ def _executor_loop(self):
937937
self.guided_decoder.init_disagg_gen_requests(
938938
scheduled_batch)
939939
if self.drafter is not None and self.use_spec_decode:
940-
if self.guided_decoder is not None:
941-
self.guided_decoder.rollback_rejected_tokens(
942-
scheduled_batch)
943-
self.drafter.prepare_draft_tokens(
944-
scheduled_batch, self.resource_manager)
940+
# When running with an external drafter, only TP rank 0 sends request to drafter
941+
if self.dist.tp_size > 1 and getattr(
942+
self.drafter, 'single_draft_call',
943+
lambda: False)():
944+
if self.dist.rank == 0:
945+
self.drafter.prepare_draft_tokens(
946+
scheduled_batch, self.resource_manager)
947+
draft_data = {}
948+
for req in scheduled_batch.generation_requests:
949+
draft_data[
950+
req.py_request_id] = req.py_draft_tokens
951+
self.dist.tp_broadcast(draft_data, root=0)
952+
else:
953+
draft_data = self.dist.tp_broadcast(None,
954+
root=0)
955+
for req in scheduled_batch.generation_requests:
956+
req.py_draft_tokens = draft_data[
957+
req.py_request_id]
958+
else:
959+
if self.guided_decoder is not None:
960+
self.guided_decoder.rollback_rejected_tokens(
961+
scheduled_batch)
962+
self.drafter.prepare_draft_tokens(
963+
scheduled_batch, self.resource_manager)
945964

946965
batch_outputs = self._forward_step(scheduled_batch)
947966
self._execute_guided_decoder(scheduled_batch,

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional
2+
from typing import List, Optional, final
33

44
from ..pyexecutor.llm_request import LlmRequest
55
from ..pyexecutor.resource_manager import ResourceManager
@@ -26,8 +26,13 @@ def prepare_draft_tokens(
2626
"""
2727
raise NotImplementedError
2828

29+
@final
def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
    """
    Decide whether speculative decoding should run this iteration.

    Avoid overriding this method: ModelEngine assumes speculation is
    always enabled when the user's spec config does not provide a
    max_concurrency value.
    """
    # No concurrency cap configured -> speculation is unconditionally on.
    if self.max_concurrency is None:
        return True
    # Otherwise speculate only while the batch is small enough.
    return len(requests) <= self.max_concurrency
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import asyncio
2+
import json
3+
from typing import List
4+
5+
import aiohttp
6+
7+
from tensorrt_llm.logger import logger
8+
9+
from ..pyexecutor.llm_request import *
10+
from ..pyexecutor.scheduler import ScheduledRequests
11+
from .drafter import Drafter
12+
13+
14+
class APIDrafter(Drafter):
    """Drafter that retrieves draft tokens from an external HTTP API.

    Each generation request's token prefix is POSTed to the configured
    endpoint; draft tokens are read from a (possibly nested) field of the
    JSON response. All failures are logged and degrade to an empty draft
    (i.e. no speculation for that request) rather than raising.
    """

    def __init__(
        self,
        spec_config: "ExternalAPIConfig",
    ):
        super().__init__()
        self.max_draft_len = spec_config.max_draft_len
        self.endpoint = spec_config.endpoint
        assert self.endpoint is not None, "API endpoint is required for external API speculative decoding."
        # Extra payload fields merged into every request (empty if unset).
        self.template = spec_config.template if spec_config.template is not None else {}
        # Dotted path locating the draft-token list in the JSON response.
        self.response_field = spec_config.response_field if spec_config.response_field is not None else "draft_tokens"

    def single_draft_call(self):
        # Tells the executor that a single rank should perform the external
        # API call and broadcast the result to the other TP ranks.
        return True

    def get_nested_field_from_response(self, response: dict) -> List[int]:
        """Extract the list at ``self.response_field`` from ``response``.

        Allows for nested fields in the response.
        Example: "choices.0.message.content" returns the value of
        response["choices"][0]["message"]["content"]. Purely numeric path
        components index into lists. Returns [] if the path is invalid or
        the resolved value is not a list.
        """
        keys = self.response_field.split(".")
        current = response

        for key in keys:
            try:
                if key.isdigit():
                    key = int(key)
                    if isinstance(current, list) and 0 <= key < len(current):
                        current = current[key]
                    else:
                        logger.warning(
                            f"Response field {self.response_field} is invalid for response {response}. Index {key} is invalid."
                        )
                        return []
                else:
                    if isinstance(current, dict) and key in current:
                        current = current[key]
                    else:
                        logger.warning(
                            f"Response field {self.response_field} is invalid for response {response}. Index {key} is invalid."
                        )
                        return []

            except (KeyError, ValueError, IndexError):
                logger.warning(
                    f"Response field path is invalid: {self.response_field}")
                return []

        if not isinstance(current, list):
            logger.warning(
                f"API response '{self.response_field}' must be a list. Got type: {type(current)}"
            )
            return []
        return current

    async def get_draft_tokens(
        self,
        prefix: List[int],
        request_id: int,
        end_id: int,
        max_sequence_length: int,
    ) -> List[int]:
        """POST one request's prefix to the endpoint and return draft tokens.

        Returns [] (no speculation) on HTTP errors, malformed JSON, or any
        other failure; truncates the result to ``self.max_draft_len``.
        """
        try:
            request_data = {
                "prefix": prefix,
                "request_id": request_id,
                "end_id": end_id,
                "max_sequence_length": max_sequence_length,
            }
            if self.template:
                request_data.update(self.template)

            async with aiohttp.ClientSession() as session:
                async with session.post(
                        url=self.endpoint,
                        json=request_data,
                        headers={"Content-Type": "application/json"},
                        timeout=aiohttp.ClientTimeout(total=10),
                ) as response:

                    # check for unsuccessful response
                    if response.status != 200:
                        logger.error(
                            f"Failed to get draft tokens. API call failed for request {request_id} with status code {response.status}"
                        )
                        return []

                    result = await response.json()
                    draft_tokens = self.get_nested_field_from_response(result)
                    # The API may return more tokens than we can use.
                    if len(draft_tokens) > self.max_draft_len:
                        draft_tokens = draft_tokens[:self.max_draft_len]
                    logger.debug(
                        f"Retrieved draft tokens for request {request_id}: {draft_tokens}"
                    )
                    return draft_tokens

        except json.JSONDecodeError as e:
            logger.error(
                f"Failed to parse JSON response for request {request_id}: {e}")
            return []

        except Exception as e:
            logger.error(
                f"Failed to get draft tokens. API call failed for request {request_id} with the following error: {e}"
            )
            return []

    async def async_prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: None,
    ) -> None:
        """Fetch draft tokens for every generation request concurrently."""
        # Sort by request_id when py_batch_idx is None as a fallback.
        # This happens in the disagg case: for a set of new requests, we draft
        # before forward_step, so py_batch_idx is not assigned.
        sorted_requests = sorted(
            scheduled_requests.generation_requests,
            key=lambda r:
            (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
        )

        tasks = []
        for request in sorted_requests:
            # Add new token to a copy of the generated tokens to find new draft tokens
            prefix = list(request.get_tokens()[0])  # Get a copy
            task = self.get_draft_tokens(
                prefix,
                request.request_id,
                request.py_end_id,
                request.py_orig_prompt_len + request.py_max_new_tokens,
            )
            tasks.append(task)

        try:
            all_draft_tokens = await asyncio.wait_for(asyncio.gather(
                *tasks, return_exceptions=True),
                                                      timeout=10.0)
        except asyncio.TimeoutError:
            logger.error(
                "Timeout occurred while getting draft tokens for batch of requests"
            )
            all_draft_tokens = [[] for _ in tasks]

        for request, draft_tokens in zip(sorted_requests, all_draft_tokens):
            if isinstance(draft_tokens, Exception):
                logger.error(
                    f"An exception occurred while getting draft tokens for request {request.request_id}. Set TLLM_LOG_LEVEL for more details."
                )
                draft_tokens = []
            elif len(draft_tokens) == 0:
                logger.error(
                    f"Draft tokens could not be generated for request {request.request_id}. Set TLLM_LOG_LEVEL for more details."
                )
            else:
                # Pad length to `self.max_draft_len` with end_id tokens so that
                # downstream code always sees a full-length draft. This branch
                # already guarantees a non-empty list, so no extra guard needed.
                pad_length = self.max_draft_len - len(draft_tokens)
                draft_tokens.extend([request.py_end_id] * pad_length)

            request.py_draft_tokens = draft_tokens

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: None,
    ) -> None:
        """Synchronous entry point; drives the async batch fetch to completion."""
        asyncio.run(
            self.async_prepare_draft_tokens(scheduled_requests,
                                            resource_manager))

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class SpeculativeDecodingMode(IntEnum):
1717
NGRAM = auto()
1818
DRAFT_TARGET = auto()
1919
USER_PROVIDED = auto()
20+
EXTERNAL_API = auto()
2021
NONE = auto()
2122
AUTO = auto()
2223

@@ -44,6 +45,9 @@ def is_user_provided(self):
4445
def is_none(self):
4546
return self == SpeculativeDecodingMode.NONE
4647

48+
def is_external_api(self):
49+
return self == SpeculativeDecodingMode.EXTERNAL_API
50+
4751
def is_draft_target(self):
4852
return self == SpeculativeDecodingMode.DRAFT_TARGET
4953

@@ -79,7 +83,7 @@ def has_spec_decoder(self):
7983

8084
def has_spec_drafter(self):
8185
return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
82-
) or self.is_user_provided()
86+
) or self.is_user_provided() or self.is_external_api()
8387

8488
def extend_ctx(self, attention_backend: Type[AttentionBackend]):
8589
"""
@@ -91,8 +95,8 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
9195
# Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
9296
return ((self.is_eagle3() or self.is_draft_target())
9397
and not (issubclass(attention_backend, TrtllmAttention)
94-
and get_sm_version() == 100)
95-
) or self.is_ngram() or self.is_user_provided()
98+
and get_sm_version() == 100)) or self.is_ngram(
99+
) or self.is_user_provided() or self.is_external_api()
96100

97101
def attention_need_spec_dec_mode(self):
98102
"""

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
88
Eagle3OneModelWorker, Eagle3ResourceManager,
99
Eagle3SpecMetadata)
10+
from .external_api import APIDrafter
1011
from .model_drafter import ModelDrafter
1112
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
1213
MTPSpecMetadata, MTPWorker)
@@ -50,7 +51,8 @@ def get_spec_metadata(spec_config,
5051
)
5152
if spec_config.spec_dec_mode.is_draft_target() or \
5253
spec_config.spec_dec_mode.is_ngram() or \
53-
spec_config.spec_dec_mode.is_user_provided():
54+
spec_config.spec_dec_mode.is_user_provided() or \
55+
spec_config.spec_dec_mode.is_external_api():
5456
return SpecMetadata(
5557
max_draft_len=spec_config.max_draft_len,
5658
spec_dec_mode=spec_config.spec_dec_mode,
@@ -99,6 +101,8 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
99101
return NGramPoolManager(spec_config, max_num_requests)
100102
if spec_dec_mode.is_user_provided():
101103
return spec_config.resource_manager
104+
if spec_dec_mode.is_external_api():
105+
return None
102106
return None
103107

104108

@@ -142,6 +146,9 @@ def get_spec_drafter(model_engine,
142146
if spec_config.spec_dec_mode.is_ngram():
143147
return NGramDrafter(spec_config, spec_resource_manager)
144148

149+
if spec_config.spec_dec_mode.is_external_api():
150+
return APIDrafter(spec_config)
151+
145152
return None
146153

147154

tensorrt_llm/llmapi/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
CapacitySchedulerPolicy, ContextChunkingPolicy,
1010
CudaGraphConfig, DraftTargetDecodingConfig,
1111
DynamicBatchConfig, EagleDecodingConfig,
12-
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
13-
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
14-
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
15-
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
16-
UserProvidedDecodingConfig)
12+
ExtendedRuntimePerfKnobConfig, ExternalAPIConfig,
13+
KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
14+
MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
15+
NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
16+
TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
1717
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
1818
QuantConfig)
1919
from .mpi_session import MpiCommSession
@@ -49,6 +49,7 @@
4949
'CacheTransceiverConfig',
5050
'NGramDecodingConfig',
5151
'UserProvidedDecodingConfig',
52+
'ExternalAPIConfig',
5253
'TorchCompileConfig',
5354
'DraftTargetDecodingConfig',
5455
'LlmArgs',

0 commit comments

Comments
 (0)