
Commit ee20f84

crazydemo authored and dominicshanshan committed
[TRTLLM-6975][test] Add multi-turn test cases for VLM models (NVIDIA#6749)
Signed-off-by: Ivy Zhang <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 5f939b9 commit ee20f84

File tree: 7 files changed (+204, -4 lines)

examples/llm-api/quickstart_multimodal.py

Lines changed: 83 additions & 0 deletions
@@ -122,6 +122,15 @@ def add_multimodal_args(parser):
               " ├── __init__.py"
               " ├── <model_name>.py"
               " └── <sub_dirs>"))
+    # Add multiturn conversation related parameters
+    parser.add_argument("--multiturn",
+                        action="store_true",
+                        help="Enable multi-turn conversation mode.")
+    parser.add_argument(
+        "--conversation_turns",
+        type=int,
+        default=2,
+        help="Number of conversation turns for automated testing.")
     return parser
 
 
@@ -188,6 +197,80 @@ def main():
         f"Unsupported model_type: {model_type} found!\n" \
         f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"
 
+    # If multiturn mode is enabled
+    if args.multiturn:
+        # Run predefined multiturn conversation examples
+        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
+        assert args.media is not None, "Please provide media for multiturn conversation."
+        # Determine how many turns to run
+        max_turns = min(args.conversation_turns, len(args.prompt))
+        generated_outputs = []  # Store generated outputs for return
+
+        # Initialize conversation history with the first prompt
+        conversation_history = args.prompt[0] if args.prompt else ""
+
+        for i in range(max_turns):
+            print(f"\n--- Turn {i+1} ---")
+
+            try:
+                # Use multimodal input loader to process input with conversation context
+                # Use accumulated conversation history instead of just the current prompt
+                cur_prompt = conversation_history
+                inputs = default_multimodal_input_loader(
+                    tokenizer=llm.tokenizer,
+                    model_dir=llm._hf_model_dir,
+                    model_type=model_type,
+                    modality=args.modality,
+                    prompts=[cur_prompt],
+                    media=args.media,
+                    image_data_format="pt",
+                    num_frames=8,
+                    device="cpu")
+
+                lora_request = None
+                if args.load_lora:
+                    if model_class is None:
+                        raise ValueError(
+                            "model_class must be provided when load_lora is True"
+                        )
+                    lora_request = model_class.lora_request(
+                        len(inputs), args.modality, llm._hf_model_dir)
+
+                # Generate response
+                outputs = llm.generate(inputs,
+                                       sampling_params,
+                                       lora_request=lora_request)
+                assert outputs and len(
+                    outputs) > 0 and outputs[0].outputs and len(
+                        outputs[0].outputs) > 0
+                response = outputs[0].outputs[0].text.strip()
+
+                # Store generated output
+                generated_outputs.append({
+                    "turn": i + 1,
+                    "user_input": cur_prompt,
+                    "assistant_response": response,
+                    "media": args.media
+                })
+
+                conversation_history = conversation_history + "\n" + response
+                if i + 1 < len(args.prompt):
+                    conversation_history = conversation_history + "\n" + args.prompt[
+                        i + 1]
+
+            except Exception as e:
+                print(f"Error in turn {i+1}: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+        for i, output in enumerate(generated_outputs):
+            print(
+                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
+            )
+        return
+
+    # Original single-turn processing logic
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
         args.prompt = example_medias_and_prompts[args.modality]["prompt"]
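Each turn in the loop above re-runs default_multimodal_input_loader on the full accumulated history rather than on the latest prompt alone, and the next user prompt is appended after the model's reply. A minimal standalone sketch of that accumulation, with placeholder responses standing in for real model output:

# Standalone sketch of the history accumulation used in the --multiturn branch above.
# The prompts come from the new e2e test; the responses are made-up placeholders.
prompts = [
    "Describe what you see in this image.",
    "How would you describe the atmosphere of this scene?",
]
responses = ["A granite dome rises above the valley.", "Calm and majestic."]

conversation_history = prompts[0]
for turn, response in enumerate(responses):
    # The input loader receives the whole history as the current prompt for this turn.
    print(f"--- Turn {turn + 1} prompt ---\n{conversation_history}\n")
    conversation_history += "\n" + response
    if turn + 1 < len(prompts):
        conversation_history += "\n" + prompts[turn + 1]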

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 1 deletion
@@ -23,7 +23,11 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
   - accuracy: 92.20
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 90.20
+    accuracy: 92.20
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    spec_dec_algo: Eagle
+    accuracy: 92.20
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 89.70
   - quant_algo: NVFP4

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
     kv_cache_quant_algo: FP8
     spec_dec_algo: Eagle
     accuracy: 86.40
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 86.40
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 80.00
   - quant_algo: NVFP4
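Both reference files key a model's expected score on the optional quant_algo, kv_cache_quant_algo, and spec_dec_algo fields, so the new FP8 + FP8-KV entry (without Eagle) is distinguishable from the existing Eagle entry. A hypothetical lookup sketch showing how such an entry list could be resolved; this is illustrative only, not the repository's actual accuracy-harness code:

# Hypothetical illustration of matching a reference entry by its optional keys;
# the accuracy harness in this repository may implement the lookup differently.
import yaml

def find_reference(entries, quant_algo=None, kv_cache_quant_algo=None,
                   spec_dec_algo=None):
    # Return the first entry whose optional keys all equal the requested config.
    for entry in entries:
        if (entry.get("quant_algo") == quant_algo
                and entry.get("kv_cache_quant_algo") == kv_cache_quant_algo
                and entry.get("spec_dec_algo") == spec_dec_algo):
            return entry
    return None

refs = yaml.safe_load("""
meta-llama/Llama-4-Maverick-17B-128E-Instruct:
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 86.40
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    spec_dec_algo: Eagle
    accuracy: 86.40
""")

entry = find_reference(refs["meta-llama/Llama-4-Maverick-17B-128E-Instruct"],
                       quant_algo="FP8", kv_cache_quant_algo="FP8")
print(entry["accuracy"])  # 86.4 -- the non-Eagle FP8 entry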

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 5 additions & 1 deletion
@@ -286,7 +286,7 @@ def run_parallel_test(model_name: str,
     total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
     total_gen_gpus = gen_tp * gen_pp * gen_instances
     if total_ctx_gpus + total_gen_gpus > get_device_count():
-        pytest.fail(
+        pytest.skip(
             f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )
 
@@ -424,6 +424,7 @@ def test_ngram(self):
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(2)
+    @skip_pre_hopper
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("eagle3_one_model", [True, False])
     def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -601,6 +602,7 @@ def test_multi_instance(self, testset):
 
 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)
+@pytest.mark.skip_less_device(4)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -685,6 +687,7 @@ def test_nixl_backend(self):
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("mtp_nextn",
                           [0, pytest.param(2, marks=skip_pre_hopper)])
+    @pytest.mark.skip_less_device(4)
     def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         ctx_server_config = {"disable_overlap_scheduler": True}
         gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -818,6 +821,7 @@ def test_nixl_backend(self):
         task.evaluate(llm)
 
     @pytest.mark.parametrize("overlap_scheduler", [False, True])
+    @skip_pre_hopper
     def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config = {
             "disable_overlap_scheduler": True,

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 1 addition & 0 deletions
@@ -886,6 +886,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
     MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"
 
+    @pytest.mark.skip_less_device_memory(80000)
     def test_auto_dtype(self):
         with LLM(self.MODEL_PATH) as llm:
             task = CnnDailymail(self.MODEL_NAME)
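test_auto_dtype for Mistral-Small-3.1-24B is now gated on roughly 80 GB of device memory. For illustration only, a gate with similar intent could be written as a plain decorator; the repository's actual skip_less_device_memory marker is handled by the test infrastructure and may be implemented differently:

# Hypothetical memory-gate decorator, for illustration; not the repository's
# skip_less_device_memory implementation.
import pytest
import torch

def skip_below_device_memory_mb(required_mb: int):
    if not torch.cuda.is_available():
        return pytest.mark.skip(reason="CUDA is not available")
    total_mb = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)
    return pytest.mark.skipif(
        total_mb < required_mb,
        reason=f"device has {total_mb} MiB, test needs {required_mb} MiB")

@skip_below_device_memory_mb(80000)
def test_auto_dtype_placeholder():
    ...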

tests/integration/defs/test_e2e.py

Lines changed: 104 additions & 2 deletions
@@ -2090,7 +2090,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
 def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
         llm_root, llm_venv, model_name, model_path, cuda_graph):
     print(f"Testing {model_name} on 8 GPUs.")
-    example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
     cmd = [
         str(example_root / "quickstart_advanced.py"),
         "--enable_chunked_prefill",
@@ -2115,10 +2115,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
 @pytest.mark.skip_less_device_memory(80000)
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("model_name,model_path", [
-    ("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
     ('Nemotron-Super-49B-v1-BF16',
      'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
     ("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
+    pytest.param('Llama3.1-70B-BF16',
+                 'llama-3.1-model/Meta-Llama-3.1-70B',
+                 marks=pytest.mark.skip_less_device_memory(95000)),
 ])
 def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
                                              model_path):
@@ -2565,6 +2567,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
     print("All answers are correct!")
 
 
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("model_name,model_path", [
+    ("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
+    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
+    ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
+])
+def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
+                                             model_path):
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    test_data_root = Path(
+        os.path.join(llm_models_root(), "multimodals", "test_data"))
+
+    print(f"Accuracy test {model_name} image mode with example inputs.")
+
+    # Define accuracy inputs for image modality
+    accuracy_inputs = {
+        "image": {
+            "prompt": [
+                "Describe what you see in this image.",
+                "How would you describe the atmosphere of this scene?",
+            ],
+            "media": [
+                str(test_data_root / "inpaint.png"),
+            ],
+        }
+    }
+
+    # Define expected keywords for each model
+    expected_keywords = {
+        "gemma-3-27b-it": {
+            "image": [
+                ["half", "dome", "yosemite", "landmark", "rounded"],
+                ["atmosphere", "peaceful", "majestic", "calm", "quiet"],
+            ],
+        },
+        "mistral-small-3.1-24b-instruct": {
+            "image": [
+                ["depicts", "landscape", "rock", "sky", "high", "altitude"],
+                ["atmosphere", "serene", "majestic", "sense", "tranquility"],
+            ],
+        },
+        "Phi-4-multimodal-instruct": {
+            "image": [
+                ["depicts", "landscape", "mountain", "half", "dome"],
+                ["atmosphere", "serene", "sense", "tranquility", "peace."],
+            ],
+        },
+    }
+    # Build command for image modality
+    cmd = [
+        str(example_root / "quickstart_multimodal.py"),
+        "--model_dir",
+        f"{llm_models_root()}/{model_path}",
+        "--modality",
+        "image",
+        "--multiturn",
+        "--prompt",
+        *accuracy_inputs["image"]["prompt"],
+        "--media",
+        *accuracy_inputs["image"]["media"],
+    ]
+
+    # Add model-specific configurations
+    if model_name == "gemma-3-27b-it":
+        # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
+        # Custom mask involves bidirectional masking of image tokens in context phase. To get this
+        # correct, chunked prefill and kv cache reuse need to be turned off.
+        cmd.append("--image_format=pil")
+        cmd.append("--attention_backend=FLASHINFER")
+        cmd.append("--disable_kv_cache_reuse")
+    elif model_name == "Phi-4-multimodal-instruct":
+        # Set max_seq_len to 4096 to use short rope factor.
+        cmd.append("--max_seq_len=4096")
+        cmd.append("--load_lora")
+        cmd.append("--auto_model_name")
+        cmd.append("Phi4MMForCausalLM")
+
+    output = llm_venv.run_cmd(cmd, caller=check_output)
+    print("output:", output)
+    # Set match ratio based on model
+    match_ratio = 4.0 / 5
+    if model_name == "Phi-4-multimodal-instruct":
+        match_ratio = 0.6
+
+    # Check output accuracy
+    for prompt_output, prompt_keywords in zip(
+            parse_output(output), expected_keywords[model_name]["image"]):
+        matches = [
+            keyword in prompt_output.lower() for keyword in prompt_keywords
+        ]
+        obs_match_ratio = 1. * sum(matches) / len(matches)
+        print("prompt_output:", prompt_output)
+        print("prompt_keywords:", prompt_keywords)
+        print("matches:", matches)
+        print("obs_match_ratio:", obs_match_ratio)
+        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
+
+    print("All answers are correct!")
+
+
 @pytest.mark.parametrize("model_name,model_path", [
     ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
 ])
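The new multi-turn e2e test scores each turn by the fraction of expected keywords found in the generated text (threshold 0.8 by default, relaxed to 0.6 for Phi-4-multimodal-instruct). That check reduces to a small helper; the sketch below restates the logic with a made-up output string:

# Restatement of the keyword match-ratio check used above; the output string is
# made up for illustration.
def keyword_match_ratio(output: str, keywords: list[str]) -> float:
    matches = [keyword in output.lower() for keyword in keywords]
    return sum(matches) / len(matches)

ratio = keyword_match_ratio(
    "The image depicts Half Dome, a granite landmark in Yosemite.",
    ["half", "dome", "yosemite", "landmark", "rounded"])
assert ratio >= 4.0 / 5  # 4 of the 5 keywords are present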

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 3 additions & 0 deletions
@@ -662,6 +662,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
 test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
