
Commit 658b7be

crazydemo authored and dominicshanshan committed
[TRTLLM-6975][test] Add multi-turn test cases for VLM models (NVIDIA#6749)
Signed-off-by: Ivy Zhang <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 3d54a1a commit 658b7be

File tree

7 files changed (+204, -4 lines)


examples/llm-api/quickstart_multimodal.py

Lines changed: 83 additions & 0 deletions
@@ -122,6 +122,15 @@ def add_multimodal_args(parser):
         " ├── __init__.py"
         " ├── <model_name>.py"
         " └── <sub_dirs>"))
+    # Add multiturn conversation related parameters
+    parser.add_argument("--multiturn",
+                        action="store_true",
+                        help="Enable multi-turn conversation mode.")
+    parser.add_argument(
+        "--conversation_turns",
+        type=int,
+        default=2,
+        help="Number of conversation turns for automated testing.")
     return parser


@@ -188,6 +197,80 @@ def main():
         f"Unsupported model_type: {model_type} found!\n" \
         f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"

+    # If multiturn mode is enabled
+    if args.multiturn:
+        # Run predefined multiturn conversation examples
+        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
+        assert args.media is not None, "Please provide media for multiturn conversation."
+        # Determine how many turns to run
+        max_turns = min(args.conversation_turns, len(args.prompt))
+        generated_outputs = []  # Store generated outputs for return
+
+        # Initialize conversation history with the first prompt
+        conversation_history = args.prompt[0] if args.prompt else ""
+
+        for i in range(max_turns):
+            print(f"\n--- Turn {i+1} ---")
+
+            try:
+                # Use multimodal input loader to process input with conversation context
+                # Use accumulated conversation history instead of just the current prompt
+                cur_prompt = conversation_history
+                inputs = default_multimodal_input_loader(
+                    tokenizer=llm.tokenizer,
+                    model_dir=llm._hf_model_dir,
+                    model_type=model_type,
+                    modality=args.modality,
+                    prompts=[cur_prompt],
+                    media=args.media,
+                    image_data_format="pt",
+                    num_frames=8,
+                    device="cpu")
+
+                lora_request = None
+                if args.load_lora:
+                    if model_class is None:
+                        raise ValueError(
+                            "model_class must be provided when load_lora is True"
+                        )
+                    lora_request = model_class.lora_request(
+                        len(inputs), args.modality, llm._hf_model_dir)
+
+                # Generate response
+                outputs = llm.generate(inputs,
+                                       sampling_params,
+                                       lora_request=lora_request)
+                assert outputs and len(
+                    outputs) > 0 and outputs[0].outputs and len(
+                        outputs[0].outputs) > 0
+                response = outputs[0].outputs[0].text.strip()
+
+                # Store generated output
+                generated_outputs.append({
+                    "turn": i + 1,
+                    "user_input": cur_prompt,
+                    "assistant_response": response,
+                    "media": args.media
+                })
+
+                conversation_history = conversation_history + "\n" + response
+                if i + 1 < len(args.prompt):
+                    conversation_history = conversation_history + "\n" + args.prompt[
+                        i + 1]
+
+            except Exception as e:
+                print(f"Error in turn {i+1}: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+        for i, output in enumerate(generated_outputs):
+            print(
+                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
+            )
+        return
+
+    # Original single-turn processing logic
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
         args.prompt = example_medias_and_prompts[args.modality]["prompt"]
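The new flags can be exercised directly with the example script. A minimal usage sketch, assuming a locally available checkpoint and test image (the two paths are illustrative placeholders; --model_dir, --modality, --prompt, and --media already exist in quickstart_multimodal.py):

python3 examples/llm-api/quickstart_multimodal.py \
    --model_dir /path/to/gemma-3-27b-it \
    --modality image \
    --multiturn \
    --conversation_turns 2 \
    --prompt "Describe what you see in this image." \
             "How would you describe the atmosphere of this scene?" \
    --media /path/to/inpaint.png

Each turn feeds the accumulated conversation history (previous prompts plus model responses) back through default_multimodal_input_loader, and the loop runs min(conversation_turns, number of prompts) turns.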

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 1 deletion
@@ -23,7 +23,11 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
   - accuracy: 92.20
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 90.20
+    accuracy: 92.20
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    spec_dec_algo: Eagle
+    accuracy: 92.20
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 89.70
   - quant_algo: NVFP4

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
     kv_cache_quant_algo: FP8
     spec_dec_algo: Eagle
     accuracy: 86.40
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 86.40
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 80.00
   - quant_algo: NVFP4

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 5 additions & 1 deletion
@@ -286,7 +286,7 @@ def run_parallel_test(model_name: str,
     total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
     total_gen_gpus = gen_tp * gen_pp * gen_instances
     if total_ctx_gpus + total_gen_gpus > get_device_count():
-        pytest.fail(
+        pytest.skip(
             f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )

@@ -421,6 +421,7 @@ def test_ngram(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_hopper
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("eagle3_one_model", [True, False])
     def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -597,6 +598,7 @@ def test_multi_instance(self, testset):

 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)
+@pytest.mark.skip_less_device(4)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -678,6 +680,7 @@ def test_nixl_backend(self):
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("mtp_nextn",
                           [0, pytest.param(2, marks=skip_pre_hopper)])
+    @pytest.mark.skip_less_device(4)
     def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         ctx_server_config = {"disable_overlap_scheduler": True}
         gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -811,6 +814,7 @@ def test_nixl_backend(self):
         task.evaluate(llm)

     @pytest.mark.parametrize("overlap_scheduler", [False, True])
+    @skip_pre_hopper
     def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config = {
             "disable_overlap_scheduler": True,

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 1 addition & 0 deletions
@@ -814,6 +814,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
     MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"

+    @pytest.mark.skip_less_device_memory(80000)
     def test_auto_dtype(self):
         with LLM(self.MODEL_PATH) as llm:
             task = CnnDailymail(self.MODEL_NAME)

tests/integration/defs/test_e2e.py

Lines changed: 104 additions & 2 deletions
@@ -2076,7 +2076,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
 def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
         llm_root, llm_venv, model_name, model_path, cuda_graph):
     print(f"Testing {model_name} on 8 GPUs.")
-    example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
     cmd = [
         str(example_root / "quickstart_advanced.py"),
         "--enable_chunked_prefill",
@@ -2101,10 +2101,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
 @pytest.mark.skip_less_device_memory(80000)
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("model_name,model_path", [
-    ("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
     ('Nemotron-Super-49B-v1-BF16',
      'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
     ("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
+    pytest.param('Llama3.1-70B-BF16',
+                 'llama-3.1-model/Meta-Llama-3.1-70B',
+                 marks=pytest.mark.skip_less_device_memory(95000)),
 ])
 def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
                                              model_path):
@@ -2551,6 +2553,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
     print("All answers are correct!")


+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("model_name,model_path", [
+    ("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
+    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
+    ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
+])
+def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
+                                             model_path):
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    test_data_root = Path(
+        os.path.join(llm_models_root(), "multimodals", "test_data"))
+
+    print(f"Accuracy test {model_name} image mode with example inputs.")
+
+    # Define accuracy inputs for image modality
+    accuracy_inputs = {
+        "image": {
+            "prompt": [
+                "Describe what you see in this image.",
+                "How would you describe the atmosphere of this scene?",
+            ],
+            "media": [
+                str(test_data_root / "inpaint.png"),
+            ],
+        }
+    }
+
+    # Define expected keywords for each model
+    expected_keywords = {
+        "gemma-3-27b-it": {
+            "image": [
+                ["half", "dome", "yosemite", "landmark", "rounded"],
+                ["atmosphere", "peaceful", "majestic", "calm", "quiet"],
+            ],
+        },
+        "mistral-small-3.1-24b-instruct": {
+            "image": [
+                ["depicts", "landscape", "rock", "sky", "high", "altitude"],
+                ["atmosphere", "serene", "majestic", "sense", "tranquility"],
+            ],
+        },
+        "Phi-4-multimodal-instruct": {
+            "image": [
+                ["depicts", "landscape", "mountain", "half", "dome"],
+                ["atmosphere", "serene", "sense", "tranquility", "peace."],
+            ],
+        },
+    }
+    # Build command for image modality
+    cmd = [
+        str(example_root / "quickstart_multimodal.py"),
+        "--model_dir",
+        f"{llm_models_root()}/{model_path}",
+        "--modality",
+        "image",
+        "--multiturn",
+        "--prompt",
+        *accuracy_inputs["image"]["prompt"],
+        "--media",
+        *accuracy_inputs["image"]["media"],
+    ]
+
+    # Add model-specific configurations
+    if model_name == "gemma-3-27b-it":
+        # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
+        # Custom mask involves bidirectional masking of image tokens in context phase. To get this
+        # correct, chunked prefill and kv cache reuse need to be turned off.
+        cmd.append("--image_format=pil")
+        cmd.append("--attention_backend=FLASHINFER")
+        cmd.append("--disable_kv_cache_reuse")
+    elif model_name == "Phi-4-multimodal-instruct":
+        # Set max_seq_len to 4096 to use short rope factor.
+        cmd.append("--max_seq_len=4096")
+        cmd.append("--load_lora")
+        cmd.append("--auto_model_name")
+        cmd.append("Phi4MMForCausalLM")
+
+    output = llm_venv.run_cmd(cmd, caller=check_output)
+    print("output:", output)
+    # Set match ratio based on model
+    match_ratio = 4.0 / 5
+    if model_name == "Phi-4-multimodal-instruct":
+        match_ratio = 0.6
+
+    # Check output accuracy
+    for prompt_output, prompt_keywords in zip(
+            parse_output(output), expected_keywords[model_name]["image"]):
+        matches = [
+            keyword in prompt_output.lower() for keyword in prompt_keywords
+        ]
+        obs_match_ratio = 1. * sum(matches) / len(matches)
+        print("prompt_output:", prompt_output)
+        print("prompt_keywords:", prompt_keywords)
+        print("matches:", matches)
+        print("obs_match_ratio:", obs_match_ratio)
+        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
+
+    print("All answers are correct!")
+
+
 @pytest.mark.parametrize("model_name,model_path", [
     ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
 ])
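With the test list updated below, a single variant of the new multi-turn test can be selected by its node ID. A typical local invocation might look like the following (runner options and model-cache setup are environment-specific and omitted):

pytest "tests/integration/defs/test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]"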

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 3 additions & 0 deletions
@@ -651,6 +651,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
 test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
