Skip to content

Commit 8895334

Browse files
committed
fix some tests
Signed-off-by: Bo Deng <[email protected]>
1 parent bbebac2 commit 8895334

File tree

2 files changed

+49
-44
lines changed

2 files changed

+49
-44
lines changed

tests/integration/defs/disaggregated/test_configs/disagg_config_for_benchmark.yaml

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
import subprocess
19+
import tempfile
1920

2021
import pytest
2122
import yaml
@@ -1201,6 +1202,43 @@ def run_disaggregated_benchmark(example_dir,
12011202
workers_proc.wait()
12021203

12031204

1205+
def get_config_for_benchmark(model_root, backend):
1206+
serve_config = {
1207+
"model": model_root,
1208+
"hostname": "localhost",
1209+
"port": 8000,
1210+
"backend": "pytorch",
1211+
"context_servers": {
1212+
"num_instances": 1,
1213+
"max_batch_size": 2,
1214+
"max_num_tokens": 384,
1215+
"max_seq_len": 320,
1216+
"tensor_parallel_size": 1,
1217+
"pipeline_parallel_size": 1,
1218+
"disable_overlap_scheduler": True,
1219+
"cache_transceiver_config": {
1220+
"backend": backend,
1221+
"max_tokens_in_buffer": 512,
1222+
},
1223+
"urls": ["localhost:8001"]
1224+
},
1225+
"generation_servers": {
1226+
"num_instances": 1,
1227+
"tensor_parallel_size": 1,
1228+
"pipeline_parallel_size": 1,
1229+
"max_batch_size": 2,
1230+
"max_num_tokens": 384,
1231+
"max_seq_len": 320,
1232+
"cache_transceiver_config": {
1233+
"backend": backend,
1234+
"max_tokens_in_buffer": 512,
1235+
},
1236+
"urls": ["localhost:8002"]
1237+
}
1238+
}
1239+
return serve_config
1240+
1241+
12041242
@pytest.mark.parametrize("benchmark_model_root", [
12051243
'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
12061244
'llama-3.1-8b-instruct-hf-fp8'
@@ -1209,32 +1247,28 @@ def run_disaggregated_benchmark(example_dir,
12091247
def test_disaggregated_benchmark_on_diff_backends(
12101248
disaggregated_test_root, disaggregated_example_root, llm_venv,
12111249
benchmark_model_root, benchmark_root, shared_gpt_path):
1212-
base_config_path = os.path.join(os.path.dirname(__file__), "test_configs",
1213-
"disagg_config_for_benchmark.yaml")
1214-
with open(base_config_path, 'r', encoding='utf-8') as f:
1215-
config = yaml.load(f, Loader=yaml.SafeLoader)
1216-
config["model"] = benchmark_model_root
1217-
with open("ucx_config.yaml", 'w', encoding='utf-8') as ucx_config:
1218-
yaml.dump(config, ucx_config)
1219-
config["context_servers"]["cache_transceiver_config"][
1220-
"backend"] = "nixl"
1221-
config["generation_servers"]["cache_transceiver_config"][
1222-
"backend"] = "nixl"
1223-
with open("nixl_config.yaml", 'w', encoding='utf-8') as nixl_config:
1224-
yaml.dump(config, nixl_config)
1250+
nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
1251+
ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
1252+
temp_dir = tempfile.TemporaryDirectory()
1253+
nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
1254+
ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
1255+
with open(nixl_config_path, 'w', encoding='utf-8') as f:
1256+
yaml.dump(nixl_config, f)
1257+
with open(ucx_config_path, 'w', encoding='utf-8') as f:
1258+
yaml.dump(ucx_config, f)
12251259

12261260
env = llm_venv._new_env.copy()
12271261
nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
12281262
disaggregated_example_root,
1229-
f"{os.path.dirname(__file__)}/nixl_config.yaml",
1263+
nixl_config_path,
12301264
benchmark_root,
12311265
benchmark_model_root,
12321266
shared_gpt_path,
12331267
env=env,
12341268
cwd=llm_venv.get_working_directory())
12351269
ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
12361270
disaggregated_example_root,
1237-
f"{os.path.dirname(__file__)}/ucx_config.yaml",
1271+
ucx_config_path,
12381272
benchmark_root,
12391273
benchmark_model_root,
12401274
shared_gpt_path,

0 commit comments

Comments
 (0)