diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 8d7f2f32341..5c5e573492b 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -963,7 +963,9 @@ def _prepare_and_schedule_batch(self):
 
         if self.drafter is not None:
             self.use_spec_decode = self.drafter.should_use_spec_decode(
-                self.active_requests)
+                self.active_requests, self.max_batch_size,
+                self.model_engine.max_num_tokens,
+                self.model_engine.spec_config.max_draft_len)
             self.model_engine.enable_spec_decode = self.use_spec_decode
             # If speculation is off, this function sets py_draft_tokens to None
             # for all active requests. If it's on, we initialize py_draft_tokens
diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py
index 82d816b8001..4fd4ff4d7f7 100644
--- a/tensorrt_llm/_torch/speculative/drafter.py
+++ b/tensorrt_llm/_torch/speculative/drafter.py
@@ -27,12 +27,28 @@ def prepare_draft_tokens(
         raise NotImplementedError
 
     @final
-    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+    def should_use_spec_decode(self, requests: List[LlmRequest],
+                               max_batch_size: int, max_num_tokens: int,
+                               max_draft_len: int) -> bool:
         """
         You probably don't want to override this. ModelEngine
         assumes that speculation is always on if max_concurrency
         is not specified by the user's spec config.
         """
-        if self.max_concurrency is not None:
-            return len(requests) <= self.max_concurrency
-        return True
+
+        # Inputs typically validated upstream: max_batch_size > 0, max_num_tokens > 0, max_draft_len >= 0
+
+        if self.max_concurrency is None:
+            return True
+
+        # Defensive guards; keep behavior explicit for zero/empty inputs
+        if not requests or max_batch_size <= 0 or max_num_tokens <= 0:
+            return False
+
+        tokens_per_request = 1 + max_draft_len
+        token_cap = max_num_tokens // tokens_per_request
+        if token_cap <= 0:
+            return False
+
+        num_effective_requests = min(len(requests), max_batch_size, token_cap)
+        return num_effective_requests <= self.max_concurrency
diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
index 3bb453a69d1..d1d9510b925 100644
--- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
+++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -51,7 +51,8 @@ def test_dynamic_spec_decode():
     )
 
     # Mock should_use_spec_decode to return True for first two calls, then False
-    def mock_should_use_spec_decode(self, requests):
+    def mock_should_use_spec_decode(self, requests, max_batch_size,
+                                    max_num_tokens, max_draft_len):
         if not hasattr(mock_should_use_spec_decode, 'call_count'):
             mock_should_use_spec_decode.call_count = 0
         mock_should_use_spec_decode.call_count += 1
@@ -86,5 +87,60 @@ def mock_should_use_spec_decode(self, requests):
     assert text_spec == text_ref
 
 
+def test_should_use_spec_decode():
+    from tensorrt_llm._torch.speculative.drafter import Drafter
+
+    class _DummyDrafter(Drafter):
+
+        def prepare_draft_tokens(self,
+                                 scheduled_requests,
+                                 resource_manager=None) -> None:
+            return
+
+    drafter = _DummyDrafter(max_concurrency=6)
+
+    # Compare min(len(requests), max_batch_size, token_cap) with max_concurrency
+
+    # Small active_requests ON case: num_effective_requests = min(5, 8, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 5
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4)
+
+    # Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=5,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4)
+
+    # Small token budget ON case: token_cap = 28 // (1 + 4) = 5 → min(12, 8, 5) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=28,
+                                          max_draft_len=4)
+
+    # Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False
+    active_requests = [object()] * 12
+    assert not drafter.should_use_spec_decode(active_requests,
+                                              max_batch_size=8,
+                                              max_num_tokens=4096 * 8,
+                                              max_draft_len=4)
+
+    # Edge case - empty active requests OFF case
+    active_requests = []
+    assert not drafter.should_use_spec_decode(active_requests,
+                                              max_batch_size=8,
+                                              max_num_tokens=4096 * 8,
+                                              max_draft_len=4)
+
+    # Edge case - token cap equals 0 OFF case: token_cap = 4 // (1 + 4) = 0 → False (token_cap <= 0 guard)
+    active_requests = [object()] * 12
+    assert not drafter.should_use_spec_decode(
+        active_requests, max_batch_size=8, max_num_tokens=4, max_draft_len=4)
+
+
 if __name__ == "__main__":
     unittest.main()
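For review convenience, below is a minimal standalone sketch of the gating rule introduced in `drafter.py` above. The function name `spec_decode_enabled` and the free-function form are illustrative only (the real method lives on `Drafter` and is shown in the diff); the arithmetic mirrors `should_use_spec_decode` so it can be checked without importing tensorrt_llm:

```python
from typing import Optional


def spec_decode_enabled(num_requests: int, max_concurrency: Optional[int],
                        max_batch_size: int, max_num_tokens: int,
                        max_draft_len: int) -> bool:
    """Standalone mirror of the gating math in Drafter.should_use_spec_decode()."""
    if max_concurrency is None:
        return True
    if num_requests <= 0 or max_batch_size <= 0 or max_num_tokens <= 0:
        return False
    # Each speculating request consumes one target token plus max_draft_len
    # draft tokens, so the per-iteration token budget caps the effective batch.
    token_cap = max_num_tokens // (1 + max_draft_len)
    if token_cap <= 0:
        return False
    effective = min(num_requests, max_batch_size, token_cap)
    return effective <= max_concurrency


if __name__ == "__main__":
    # Same numbers as test_should_use_spec_decode: max_concurrency=6, max_draft_len=4.
    assert spec_decode_enabled(5, 6, 8, 4096 * 8, 4)       # small batch → on
    assert spec_decode_enabled(12, 6, 8, 28, 4)            # token_cap = 5 → on
    assert not spec_decode_enabled(12, 6, 8, 4096 * 8, 4)  # 8 effective > 6 → off
    assert not spec_decode_enabled(12, 6, 8, 4, 4)         # token_cap = 0 → off
    print("gating sketch OK")
```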