import os
import re
import subprocess
+import tempfile

import pytest
import yaml
@@ -1201,6 +1202,43 @@ def run_disaggregated_benchmark(example_dir,
        workers_proc.wait()


+def get_config_for_benchmark(model_root, backend):
+    serve_config = {
+        "model": model_root,
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 320,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 320,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8002"]
+        }
+    }
+    return serve_config
+
+
@pytest.mark.parametrize("benchmark_model_root", [
    'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
    'llama-3.1-8b-instruct-hf-fp8'
@@ -1209,32 +1247,28 @@ def run_disaggregated_benchmark(example_dir,
def test_disaggregated_benchmark_on_diff_backends(
        disaggregated_test_root, disaggregated_example_root, llm_venv,
        benchmark_model_root, benchmark_root, shared_gpt_path):
-    base_config_path = os.path.join(os.path.dirname(__file__), "test_configs",
-                                    "disagg_config_for_benchmark.yaml")
-    with open(base_config_path, 'r', encoding='utf-8') as f:
-        config = yaml.load(f, Loader=yaml.SafeLoader)
-    config["model"] = benchmark_model_root
-    with open("ucx_config.yaml", 'w', encoding='utf-8') as ucx_config:
-        yaml.dump(config, ucx_config)
-    config["context_servers"]["cache_transceiver_config"][
-        "backend"] = "nixl"
-    config["generation_servers"]["cache_transceiver_config"][
-        "backend"] = "nixl"
-    with open("nixl_config.yaml", 'w', encoding='utf-8') as nixl_config:
-        yaml.dump(config, nixl_config)
+    nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
+    ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
+    temp_dir = tempfile.TemporaryDirectory()
+    nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
+    ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
+    with open(nixl_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(nixl_config, f)
+    with open(ucx_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(ucx_config, f)

    env = llm_venv._new_env.copy()
    nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
        disaggregated_example_root,
-        f"{os.path.dirname(__file__)}/nixl_config.yaml",
+        nixl_config_path,
        benchmark_root,
        benchmark_model_root,
        shared_gpt_path,
        env=env,
        cwd=llm_venv.get_working_directory())
    ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
        disaggregated_example_root,
-        f"{os.path.dirname(__file__)}/ucx_config.yaml",
+        ucx_config_path,
        benchmark_root,
        benchmark_model_root,
        shared_gpt_path,