
Commit b00edd6

syuoni authored and Shunkang committed
[TRTLLM-6854][feat] Enable guided decoding with disagg serving (NVIDIA#6704)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 0653551 commit b00edd6

7 files changed: +175 −24 lines changed

docs/source/torch/features/feature_combination_matrix.md

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 26 additions & 8 deletions
@@ -39,7 +39,7 @@ def __init__(self,
                 guided_decoding_config, vocab_size_padded)
         else:
             raise ValueError(
-                f"invalid guided decoding backend: {self.guided_decoding_backend}"
+                f"Invalid guided decoding backend: {self.guided_decoding_backend}"
             )
         logger.info(
             f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
@@ -71,15 +71,15 @@ def __init__(self,
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)

-    def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_init(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
             return False
         # The request is in the last chunk of a context forward step.
         return llm_req.is_context_init_state and llm_req.is_last_context_chunk

-    def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
@@ -102,12 +102,17 @@ def build(self, scheduled_requests: ScheduledRequests) -> None:
             self.num_advanced_tokens[slot] = 0
             self.num_guided_tokens[slot] = 0

-            if self._is_matcher_init(llm_req):
+            matcher_init: bool = self._require_matcher_init(llm_req)
+            matcher_advance: bool = self._require_matcher_advance(llm_req)
+            if not (matcher_init or matcher_advance):
+                continue
+
+            if matcher_init:
                 matcher = self.grammar_matcher_factory.create(
                     llm_req.guided_decoding_params)
                 self.grammar_matchers[slot] = matcher

-            elif self._is_matcher_in_progress(llm_req):
+            if matcher_advance:
                 matcher = self.grammar_matchers[slot]
                 # The last new token must be acceptable unless the matcher is terminated in a drafting loop.
                 if llm_req.py_is_draft and (matcher.is_terminated()
@@ -127,9 +132,6 @@ def build(self, scheduled_requests: ScheduledRequests) -> None:
                     f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
                 )

-            else:
-                continue
-
             self.num_advanced_tokens[slot] += 1
             if not matcher.is_terminated():
                 matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
@@ -244,3 +246,19 @@ def rollback_draft_tokens(self,
         # Reset the drafting states.
         self.num_advanced_draft_tokens[slot] = 0
         self.is_draft_terminated[slot] = False
+
+    @nvtx_range("GuidedDecoder.init_disagg_gen_requests")
+    def init_disagg_gen_requests(self,
+                                 scheduled_requests: ScheduledRequests) -> None:
+        """Initialize the grammar matchers for disagg gen requests.
+        """
+        for llm_req in scheduled_requests.generation_requests:
+            if llm_req.guided_decoding_params is None:
+                continue
+            assert not llm_req.py_is_draft
+            slot: int = llm_req.py_seq_slot
+            if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1:
+                # The request is in the first generation forward step at the disagg gen instance.
+                self.grammar_matchers[
+                    slot] = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
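
For orientation, the condition that `init_disagg_gen_requests` keys on can be read as a standalone predicate. Below is a minimal sketch, assuming only the attribute semantics implied by the diff; the helper name is hypothetical and not part of this change:

```python
# Hypothetical paraphrase of the check inside init_disagg_gen_requests.
def needs_disagg_matcher_init(llm_req) -> bool:
    if llm_req.guided_decoding_params is None:
        return False  # the request does not use guided decoding
    # context_phase_params is only populated when the prefill ran on a
    # separate context instance; py_decoding_iter == 1 marks the first
    # generation step on this gen instance, so the grammar matcher has to
    # be created locally at that point.
    return (llm_req.context_phase_params is not None
            and llm_req.py_decoding_iter == 1)
```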

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -335,7 +335,9 @@ def __init__(
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
         self.py_is_draft = is_draft
+        # The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
         self.py_seq_slot = seq_slot
+        # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot

         # TODO: remove this when use DynamicDecodeOp in pytorch flow.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 0 deletions
@@ -749,6 +749,9 @@ def _executor_loop_pp(self):
                     if self._need_return_logits(scheduled_batch):
                         logits_host = batch_outputs["logits"].to(
                             "cpu", non_blocking=True)
+                    if self.kv_cache_transceiver and self.guided_decoder:
+                        self.guided_decoder.init_disagg_gen_requests(
+                            scheduled_batch)
                     self._execute_guided_decoder(
                         scheduled_batch, batch_outputs['logits'])

@@ -939,6 +942,10 @@ def _executor_loop(self):
                         self._handle_first_token_response(scheduled_batch)

                     self.resource_manager.prepare_resources(scheduled_batch)
+
+                    if self.kv_cache_transceiver and self.guided_decoder:
+                        self.guided_decoder.init_disagg_gen_requests(
+                            scheduled_batch)
                     if self.drafter is not None and self.use_spec_decode:
                         if self.guided_decoder is not None:
                             self.guided_decoder.rollback_rejected_tokens(
@@ -1063,6 +1070,9 @@ def _executor_loop_overlap(self):
                 if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)

+                if self.kv_cache_transceiver and self.guided_decoder:
+                    self.guided_decoder.init_disagg_gen_requests(
+                        scheduled_batch)
                 self._execute_guided_decoder(scheduled_batch,
                                              batch_outputs['logits'])
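
The same three-line guard is added to all three executor loops (`_executor_loop_pp`, `_executor_loop`, `_executor_loop_overlap`) immediately before the guided-decoding step, so the matchers for disaggregated generation requests exist by the time their first bitmask is built. A minimal paraphrase follows, under the semantics implied by the diff; the helper name is hypothetical:

```python
# Hypothetical helper summarizing the guard used in each executor loop.
def maybe_init_disagg_guided_requests(self, scheduled_batch) -> None:
    # kv_cache_transceiver is only configured for disaggregated serving, so
    # this is a no-op on a monolithic instance or when guided decoding is off.
    if self.kv_cache_transceiver and self.guided_decoder:
        self.guided_decoder.init_disagg_gen_requests(scheduled_batch)
```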

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 126 additions & 15 deletions
@@ -4,6 +4,7 @@
 # Please take a look at the existing test_llm_api_pytorch.py file for reference.
 import concurrent
 import contextlib
+import json
 import os
 import tempfile
 import time
@@ -19,12 +20,13 @@
 from tensorrt_llm.executor.result import GenerationResultBase
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
+from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer

 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                         skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
-                            get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
+                            LlmapiAccuracyTestHarness, get_accuracy_task)


 class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ def result(self):
         return self


-DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
+DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])


 class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],

     def send_request(prompt: str, sampling_params: SamplingParams,
                      streaming: bool):
-        response = client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            stream=streaming,
-            **({
-                "max_tokens": sampling_params.max_tokens,
-                "temperature": sampling_params.temperature,
-                "top_p": sampling_params.top_p,
-                "stop": sampling_params.stop,
-                "seed": sampling_params.seed
-            } if sampling_params else {}))
+        kwargs = {}
+        if sampling_params is not None:
+            kwargs.update(max_tokens=sampling_params.max_tokens,
+                          temperature=sampling_params.temperature,
+                          top_p=sampling_params.top_p,
+                          stop=sampling_params.stop,
+                          seed=sampling_params.seed)
+            if (guided_decoding_params :=
+                    sampling_params.guided_decoding) is not None:
+                extra_body = {}
+                if (schema := guided_decoding_params.json) is not None:
+                    extra_body.update(response_format={
+                        "type": "json",
+                        "schema": json.loads(schema)
+                    })
+                elif guided_decoding_params.json_object:
+                    extra_body.update(
+                        response_format={"type": "json_object"})
+                else:
+                    # TODO: Support other guided decoding types
+                    raise ValueError(
+                        f"Unsupported guided decoding params: {guided_decoding_params}."
+                    )
+                kwargs.update(extra_body=extra_body)
+
+        response = client.completions.create(model=model_name,
+                                              prompt=prompt,
+                                              stream=streaming,
+                                              **kwargs)
         result = Result(id=0,
                         sampling_params=sampling_params,
                         outputs=[
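
With this branch in place, a test can drive guided decoding through the disaggregated harness simply by attaching guided-decoding params to its `SamplingParams`. A hedged usage sketch follows; the import path, schema, and prompt are illustrative assumptions, not part of this diff:

```python
# Illustrative only: what a caller of the DuckLLM fixture might construct.
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

schema = '{"type": "object", "properties": {"answer": {"type": "string"}}}'
sampling_params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema),
)
# send_request above translates these params into the OpenAI-style body
#   extra_body={"response_format": {"type": "json", "schema": {...}}}
# before calling client.completions.create.
```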
@@ -192,8 +212,10 @@ def generate_async(prompt: str,
             thread_pool.futures.append(future)
             return future

+    tokenizer = load_hf_tokenizer(model_name)
+
     try:
-        yield DuckLLM(args, generate_async)
+        yield DuckLLM(args, tokenizer, generate_async)
     finally:
         ctx_server.terminate()
         gen_server.terminate()
@@ -394,6 +416,95 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding_with_eagle3(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 3,
+            "speculative_model_dir":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
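
Beyond the accuracy harness, the endpoint these tests launch (a router on `localhost:8000` in front of the context and generation servers) accepts the same guided-decoding request shape directly. The sketch below is illustrative only; the model name, prompt, and schema are placeholders rather than values from this diff:

```python
# Illustrative only: a JSON-constrained completion against the router that
# launch_disaggregated_llm starts on localhost:8000 in the tests above.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
response = client.completions.create(
    model="llama-3.1-8b-instruct",  # placeholder; use the served model name
    prompt="Reply with a JSON object that has a single key 'answer'.",
    max_tokens=64,
    extra_body={
        "response_format": {
            "type": "json",
            "schema": {
                "type": "object",
                "properties": {"answer": {"type": "string"}},
            },
        }
    },
)
print(response.choices[0].text)
```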

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 8 additions & 0 deletions
@@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
@@ -520,6 +524,10 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
