Commit bb15eb9

Fix accuracy drop of gsm8k
Signed-off-by: Chenfei Zhang <[email protected]>
1 parent 55f4f2d commit bb15eb9

File tree

4 files changed: +19 -11 lines changed

tensorrt_llm/executor/executor.py

Lines changed: 3 additions & 0 deletions
@@ -134,6 +134,9 @@ def generate_async(
         if postproc_params:
             postproc_params.postproc_args.num_prompt_tokens = len(
                 prompt_token_ids)
+
+        print(f"[CF][generate_async] sampling_params is ")
+        print(sampling_params)
         request = GenerationRequest(
             prompt_token_ids,
             sampling_params=sampling_params,
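
The two added print calls are debug instrumentation: they dump the SamplingParams object a request carries just before the GenerationRequest is built. A minimal sketch of what they surface, assuming a local tensorrt_llm install (the field values mirror the test changes below):

from tensorrt_llm import SamplingParams

# The greedy-decoding setup the accuracy tests use; printing the object
# shows every field that generate_async forwards into the GenerationRequest.
sampling_params = SamplingParams(
    max_tokens=256,          # cap on generated tokens per request
    temperature=0.0,         # greedy decoding for reproducible accuracy runs
    add_special_tokens=False,
)
print("[CF][generate_async] sampling_params is")
print(sampling_params)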

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 3 additions & 3 deletions
@@ -13,12 +13,12 @@ meta-llama/Llama-3.3-70B-Instruct:
   - accuracy: 83.78
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
-    accuracy: 88.70
+    accuracy: 87.33
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 84.08
+    accuracy: 90.30
   - quant_algo: FP8
-    accuracy: 84.08
+    accuracy: 90.30
 meta-llama/Llama-4-Maverick-17B-128E-Instruct:
   - accuracy: 92.20
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
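
Each reference file is keyed by model name, with one list entry per quantization configuration. A hypothetical sketch of querying such a file (lookup_reference is illustrative, not the repo's actual harness; requires PyYAML):

import yaml

# Excerpt mirroring the updated gsm8k.yaml references.
REFERENCE_YAML = """
meta-llama/Llama-3.3-70B-Instruct:
- accuracy: 83.78
- quant_algo: NVFP4
  kv_cache_quant_algo: FP8
  accuracy: 87.33
- quant_algo: FP8
  kv_cache_quant_algo: FP8
  accuracy: 90.30
- quant_algo: FP8
  accuracy: 90.30
"""

def lookup_reference(model, quant_algo=None, kv_cache_quant_algo=None):
    # Return the reference accuracy for a (model, quantization) pair;
    # an entry with no quant keys is the unquantized baseline.
    for entry in yaml.safe_load(REFERENCE_YAML)[model]:
        if (entry.get("quant_algo") == quant_algo
                and entry.get("kv_cache_quant_algo") == kv_cache_quant_algo):
            return entry["accuracy"]
    raise KeyError(f"no reference for {model} with {quant_algo}")

assert lookup_reference("meta-llama/Llama-3.3-70B-Instruct",
                        quant_algo="FP8",
                        kv_cache_quant_algo="FP8") == 90.30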

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 3 additions & 3 deletions
@@ -63,12 +63,12 @@ meta-llama/Llama-3.3-70B-Instruct:
     accuracy: 81.31
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
-    accuracy: 79.31
+    accuracy: 78.78
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 81.02
+    accuracy: 80.40
   - quant_algo: FP8
-    accuracy: 80.34
+    accuracy: 80.40
 meta-llama/Llama-4-Maverick-17B-128E-Instruct:
   - accuracy: 86.40
   - quant_algo: FP8
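
These references act as pass/fail thresholds for the accuracy suite, so stale values fail correct runs. A hypothetical sketch of the kind of comparison they feed (the tolerance logic is illustrative, not the repo's actual check):

def passes_reference(measured: float, reference: float,
                     rel_tol: float = 0.01) -> bool:
    # Pass when the measured score is at most rel_tol below the
    # reference; scores above the reference always pass.
    return measured >= reference * (1.0 - rel_tol)

# Against the old GSM8K NVFP4 reference of 88.70, a run scoring 87.33 fails;
# against the corrected reference of 87.33 it passes.
assert not passes_reference(87.33, 88.70)
assert passes_reference(87.33, 87.33)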

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 10 additions & 5 deletions
@@ -463,7 +463,7 @@ def test_eagle3_tp8(self, eagle3_one_model):
     @pytest.mark.skip_less_device(4)
     @skip_pre_hopper
     def test_fp8_tp4(self):
-        model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8"
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
         with LLM(model_path,
                  tensor_parallel_size=4,
@@ -472,6 +472,7 @@ def test_fp8_tp4(self):
                  kv_cache_config=kv_cache_config) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             sampling_params = SamplingParams(
+                max_tokens=256,
                 temperature=0.0,
                 add_special_tokens=False,
             )
@@ -481,16 +482,21 @@ def test_fp8_tp4(self):
             task.evaluate(llm, sampling_params=sampling_params)
             task = GPQADiamond(self.MODEL_NAME)
             task.evaluate(llm,
-                          sampling_params=sampling_params,
                           extra_evaluator_kwargs=dict(apply_chat_template=True))

     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     def test_nvfp4_tp4(self):
-        model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4"
-        with LLM(model_path, tensor_parallel_size=4) as llm:
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP4"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
+        with LLM(model_path,
+                 tensor_parallel_size=4,
+                 max_seq_len=8192,
+                 max_batch_size=32,
+                 kv_cache_config=kv_cache_config) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             sampling_params = SamplingParams(
+                max_tokens=256,
                 temperature=0.0,
                 add_special_tokens=False,
             )
@@ -500,7 +506,6 @@ def test_nvfp4_tp4(self):
             task.evaluate(llm, sampling_params=sampling_params)
             task = GPQADiamond(self.MODEL_NAME)
             task.evaluate(llm,
-                          sampling_params=sampling_params,
                           extra_evaluator_kwargs=dict(apply_chat_template=True))
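
Taken together, the test changes pin down the evaluation setup: corrected checkpoint paths, a bounded KV cache, explicit max_seq_len and max_batch_size, and a max_tokens cap so generation cannot run unbounded. A minimal end-to-end sketch of that pattern, assuming a TensorRT-LLM install (the checkpoint path is illustrative; the tests resolve theirs via llm_models_root()):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

# Reserve half of free GPU memory for the KV cache, as the tests now do.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)

sampling_params = SamplingParams(
    max_tokens=256,          # bound generated tokens per request
    temperature=0.0,         # greedy decoding for reproducible scores
    add_special_tokens=False,
)

with LLM("/models/Llama-3.3-70B-Instruct-FP4",  # illustrative path
         tensor_parallel_size=4,
         max_seq_len=8192,
         max_batch_size=32,
         kv_cache_config=kv_cache_config) as llm:
    outputs = llm.generate(["Q: 2 + 2 = ?\nA:"], sampling_params)
    print(outputs[0].outputs[0].text)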
