Skip to content

Commit 152a54e

Browse files
peaceh-nv authored and dominicshanshan committed
[https://nvbugs/5449218][fix] Fix KvCacheConfig error in test_perf (NVIDIA#6937)
Signed-off-by: peaceh <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 71cbb42 commit 152a54e

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,6 @@
1717
Model pytorch yaml config for trtllm-bench perf tests
1818
"""
1919

20-
from tensorrt_llm.llmapi import KvCacheConfig
21-
2220

2321
def recursive_update(d, u):
2422
for k, v in u.items():
@@ -204,9 +202,10 @@ def get_model_yaml_config(model_label: str,
204202
'swap_gate_up_proj_lora_b_weight'] = False
205203
base_config.update(lora_config)
206204

207-
kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
205+
kv_cache_config = base_config.get('kv_cache_config', {})
208206
if 'kv_cache_dtype' in base_config:
209-
kv_cache_config.dtype = base_config.pop('kv_cache_dtype', 'auto')
207+
kv_cache_dtype = base_config.pop('kv_cache_dtype', 'auto')
208+
kv_cache_config['dtype'] = kv_cache_dtype
210209
base_config.update({'kv_cache_config': kv_cache_config})
211210

212211
return base_config

0 commit comments

Comments
 (0)