Commit 48ddc3d

LinPoly authored and dc3671 committed
[fix]: Revert commit 388b491 (#6143)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 24ce6b9 commit 48ddc3d

File tree

5 files changed: +18 -47 lines changed

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 9 deletions
@@ -544,14 +544,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
                 )
-            # Check prompt length and query length against max_num_tokens to filter illegal requests.
-            # Skip check for gen-only requests
-            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
-                max_num_tokens = self.args.max_num_tokens
-                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                    raise ValueError(
-                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
-                        f"max_num_tokens ({max_num_tokens})")
             return
 
         build_config = self.args.build_config
@@ -568,7 +560,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
-                f"max_seq_len ({max_seq_len})")
+                f"max_seq_len ({build_config.max_seq_len})")
 
         if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
             if sampling_params.n == sampling_params.best_of:
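
To make the arithmetic behind the removed pre-check concrete, below is a small standalone Python sketch of the comparison it performed. This is illustration only, not the library's API: the function name and the cp_size=1 default are assumptions for the example, and the actual logic lived inside `_check_arguments` as shown in the removed lines above.

def exceeds_token_budget(prompt_len: int,
                         query_len: int,
                         max_num_tokens: int,
                         cp_size: int = 1) -> bool:
    # Mirrors the reverted condition: reject when the per-CP-rank share of the
    # prompt plus the query length exceeds the configured max_num_tokens.
    return bool(max_num_tokens) and (prompt_len / cp_size + query_len
                                     > max_num_tokens)


# With max_num_tokens=64 and a prompt of roughly 65 tokens (the tests below use
# 'A ' * 65), the budget is exceeded. Before this revert that condition raised
# an early ValueError on the PyTorch path; after it, that early rejection is gone.
print(exceeds_token_budget(prompt_len=65, query_len=0, max_num_tokens=64))  # True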

tests/unittest/llmapi/test_llm.py

Lines changed: 10 additions & 22 deletions
@@ -2089,36 +2089,24 @@ def success_path():
     success_path()
 
 
-def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
-    llm_args_extra = {}
-    if pytorch_backend:
-        LLM_CLASS = LLM_torch
-        llm_args_extra["max_num_tokens"] = 64
-    else:
-        LLM_CLASS = LLM
-        build_config = BuildConfig()
-        build_config.max_num_tokens = 64
-        llm_args_extra["fast_build"] = True
-        llm_args_extra["build_config"] = build_config
+def _test_llm_capture_request_error(tp_size: int = 1):
+    build_config = BuildConfig()
+    build_config.max_num_tokens = 64
 
-    llm = LLM_CLASS(
+    llm = LLM(
         model=llama_model_path,
-        tensor_parallel_size=tp_size,
-        **llm_args_extra,
+        build_config=build_config,
+        fast_build=True,
     )
 
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+
+    with pytest.raises(RequestError):
+        llm.generate(prompt)
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=False, tp_size=1)
+    _test_llm_capture_request_error(tp_size=1)
 
 
 def test_llm_shutdown_executor():
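
For readability, here is how the restored helper reads after this commit, reconstructed directly from the hunk above. `LLM`, `BuildConfig`, `RequestError`, `llama_model_path`, and `pytest` come from the surrounding test module; note that, per this diff, `tp_size` stays in the signature but is no longer forwarded to the `LLM` constructor.

def _test_llm_capture_request_error(tp_size: int = 1):
    # Build a TRT engine with the minimum token budget of 64.
    build_config = BuildConfig()
    build_config.max_num_tokens = 64

    llm = LLM(
        model=llama_model_path,
        build_config=build_config,
        fast_build=True,
    )

    prompt = 'A ' * 65  # the minimum max_num_tokens is 64

    # With the pytorch_backend branch removed, an over-budget prompt is always
    # expected to surface as a RequestError at generate time.
    with pytest.raises(RequestError):
        llm.generate(prompt)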

tests/unittest/llmapi/test_llm_multi_gpu.py

Lines changed: 1 addition & 1 deletion
@@ -466,7 +466,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=False, tp_size=2)
+    _test_llm_capture_request_error(tp_size=2)
 
 
 def test_llm_with_postprocess_parallel_tp2():

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 0 additions & 6 deletions
@@ -7,17 +7,11 @@
 from tensorrt_llm.lora_manager import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
-from .test_llm import _test_llm_capture_request_error
 # isort: on
 
 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
 
 
-@pytest.mark.gpu2
-def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=True, tp_size=2)
-
-
 @pytest.mark.gpu4
 def test_tinyllama_logits_processor_tp2pp2():
     tinyllama_logits_processor_test_harness(backend="pytorch",

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 6 additions & 9 deletions
@@ -6,11 +6,12 @@
 
 # isort: off
 from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request
-from .test_llm import (
-    get_model_path, global_kvcache_config, llama_model_path,
-    llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
-    run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
-    tinyllama_logits_processor_test_harness, _test_llm_capture_request_error)
+from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
+                       llm_get_stats_async_test_harness,
+                       llm_get_stats_test_harness, prompts,
+                       run_llm_abort_request,
+                       run_llm_with_postprocess_parallel_and_result_handler,
+                       tinyllama_logits_processor_test_harness)
 from utils.util import (EnvVarsContextManager, force_ampere,
                         run_function_in_sub_process, similar,
                         skip_gpu_memory_less_than_40gb,
@@ -69,10 +70,6 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
         enable_iter_req_stats=enable_iter_req_stats)
 
 
-def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=True, tp_size=1)
-
-
 @force_ampere
 @pytest.mark.parametrize(
     "sampling_params",
