7
7
import torch
8
8
9
9
from tensorrt_llm import LLM , SamplingParams
10
+ from tensorrt_llm .llmapi .llm_args import KvCacheConfig
10
11
from tensorrt_llm .mapping import CpType
11
- from tensorrt_llm .models .modeling_utils import QuantAlgo , QuantConfig
12
12
13
13
14
14
def dump_jsonl (data , fname ):
@@ -54,11 +54,8 @@ def similarity_score(a, b):
54
54
return SequenceMatcher (None , a , b ).ratio ()
55
55
56
56
57
- # Generate the outputs using either TRT or PyTorch (based on the use_pytorch argument). It’s the same function for both workflows.
58
57
def generate_llm_outputs (args , data , fp8 = False , fp8_kv_cache = False ):
59
- quant_config = QuantConfig (quant_algo = QuantAlgo .FP8 ,
60
- kv_cache_quant_algo = QuantAlgo .FP8 if fp8_kv_cache
61
- else None ) if fp8 else QuantConfig ()
58
+ kv_cache_config = KvCacheConfig (dtype = "fp8" if fp8_kv_cache else "auto" )
62
59
cp_config = {
63
60
"cp_type" : CpType .STAR ,
64
61
"cp_anchor_size" : args .sa_anchor_size ,
@@ -70,7 +67,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
70
67
max_input_len = args .max_input_len ,
71
68
max_seq_len = args .max_seq_len ,
72
69
max_num_tokens = args .max_num_tokens ,
73
- quant_config = quant_config ,
70
+ kv_cache_config = kv_cache_config ,
74
71
tensor_parallel_size = 1 ,
75
72
context_parallel_size = args .num_procs ,
76
73
cp_config = cp_config ,
0 commit comments