 # limitations under the License.
 
 import os
+import re
 import subprocess
+import tempfile
 
 import pytest
-from defs.conftest import skip_arm, skip_no_hopper
-from defs.trt_test_alternative import check_call, popen
+import yaml
+from defs.conftest import llm_models_root, skip_arm, skip_no_hopper
+from defs.trt_test_alternative import check_call, check_output, popen
 
 from tensorrt_llm.logger import logger
 
@@ -1051,3 +1054,227 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp(
                            "deepseek_v3_lite_fp8_tp1_two_mtp",
                            env=llm_venv._new_env,
                            cwd=llm_venv.get_working_directory())
+
+
+@pytest.fixture(scope="module")
+def benchmark_root():
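+    # The serving benchmark scripts live inside the TensorRT-LLM source
+    # tree; LLM_ROOT is expected to be exported by the test harness.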
+    llm_root = os.getenv("LLM_ROOT")
+    assert llm_root, "LLM_ROOT must be set to locate the benchmark scripts"
+    return os.path.join(llm_root, "tensorrt_llm", "serve", "scripts")
+
+
+@pytest.fixture(scope="module")
+def shared_gpt_path():
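+    # ShareGPT dataset under the shared model root, defaulting to the CI
+    # mount point when LLM_MODELS_ROOT is not set.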
+    DEFAULT_LLM_MODEL_ROOT = os.path.join("/scratch.trt_llm_data", "llm-models")
+    LLM_MODELS_ROOT = os.environ.get("LLM_MODELS_ROOT", DEFAULT_LLM_MODEL_ROOT)
+    return os.path.join(LLM_MODELS_ROOT, "datasets",
+                        "ShareGPT_V3_unfiltered_cleaned_split.json")
+
+
+@pytest.fixture(scope="function")
+def benchmark_model_root(request):
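+    # Resolve the model name passed via indirect parametrization to a
+    # checkpoint directory under llm_models_root().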
+    models_root = llm_models_root()
+    if request.param == "DeepSeek-V3-Lite-fp8":
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "fp8")
+    elif request.param == "DeepSeek-V3-Lite-bf16":
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "bf16")
+    elif request.param == "llama-v3-8b-hf":
+        model_path = os.path.join(models_root, "llama-models-v3", "8B")
+    elif request.param == "llama-3.1-8b-instruct-hf-fp8":
+        model_path = os.path.join(models_root, "llama-3.1-model",
+                                  "Llama-3.1-8B-Instruct-FP8")
+    else:
+        raise ValueError(f"Unknown benchmark model: {request.param}")
+    return model_path
+
+
+def run_disaggregated_benchmark(example_dir,
+                                config_file,
+                                benchmark_root,
+                                benchmark_model_root,
+                                shared_gpt_path,
+                                env=None,
+                                cwd=None):
+    """Run the disaggregated serving benchmark with the given config and
+    return (median_e2el_ms, median_ttft_ms)."""
+    run_env = (env or os.environ).copy()
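+    # Assumption: "^ib" asks UCX for every transport except InfiniBand,
+    # keeping this loopback benchmark independent of IB hardware on the
+    # runner.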
+    run_env["UCX_TLS"] = "^ib"
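+    # Two ranks: the config below runs one context instance and one
+    # generation instance.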
+    num_rank = 2
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 900
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    try:
+        with (  # Start workers
+                open('output_workers.log', 'w') as output_workers,
+                popen(workers_cmd,
+                      stdout=output_workers,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as workers_proc,
+                # Start server
+                open('output_disagg.log', 'w') as output_disagg,
+                popen(server_cmd,
+                      stdout=output_disagg,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as server_proc):
+            # Ensure the server has started
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c',
+                f'{example_dir}/disagg_config.yaml', '-p',
+                f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            # Warm up
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+            # Start benchmark
+            benchmark_script = os.path.join(benchmark_root,
+                                            "benchmark_serving.py")
+            benchmark_cmd = [
+                'python3',
+                benchmark_script,
+                '--model',
+                benchmark_model_root,
+                '--tokenizer',
+                benchmark_model_root,
+                '--dataset-name',
+                'random',
+                '--dataset-path',
+                shared_gpt_path,
+                '--random-input-len',
+                '256',
+                '--random-output-len',
+                '64',
+                '--random-prefix-len',
+                '0',
+                '--num-prompts',
+                '320',
+                '--max-concurrency',
+                '32',
+                '--host',
+                'localhost',
+                '--port',
+                '8000',
+                '--ignore-eos',
+                '--no-test-input',
+                '--percentile-metrics',
+                'e2el,ttft',
+            ]
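+            # The workload is fully fixed (256-in/64-out tokens, 320 prompts,
+            # concurrency 32) so both transceiver backends see identical
+            # traffic.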
+            # Warm-up run; only the second, measured run is parsed below
+            check_call(benchmark_cmd, env=env)
+            output = check_output(benchmark_cmd, env=env)
+            e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
+            ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
+            e2el_match = re.search(e2el_pattern, output)
+            ttft_match = re.search(ttft_pattern, output)
+            if e2el_match and ttft_match:
+                median_e2el = float(e2el_match.group(1))
+                median_ttft = float(ttft_match.group(1))
+                return median_e2el, median_ttft
+            else:
+                raise ValueError(
+                    "Failed to parse median E2EL/TTFT from benchmark output")
+
+    except Exception:
+        # Print outputs on error
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
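+        # Best-effort cleanup; normally the popen context managers have
+        # already shut the processes down by the time we get here.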
+        server_proc.terminate()
+        workers_proc.terminate()
+        server_proc.wait()
+        workers_proc.wait()
+
+
+def get_config_for_benchmark(model_root, backend):
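+    # `backend` selects the KV-cache transceiver ("ucx" or "nixl"); the
+    # serving backend itself is always pytorch.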
+    serve_config = {
+        "model": model_root,
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 384,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 384,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8002"]
+        }
+    }
+    return serve_config
+
+
+@pytest.mark.parametrize("benchmark_model_root", [
+    'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
+    'llama-3.1-8b-instruct-hf-fp8'
+],
+                         indirect=True)
+def test_disaggregated_benchmark_on_diff_backends(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        benchmark_model_root, benchmark_root, shared_gpt_path):
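+    # Build one config per cache-transceiver backend; the configs are
+    # otherwise identical, so the comparison isolates the transport layer.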
+    nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
+    ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
+    temp_dir = tempfile.TemporaryDirectory()
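+    # temp_dir must stay referenced for the whole test: the directory (and
+    # the config files in it) is deleted when the object is finalized.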
+    nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
+    ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
+    with open(nixl_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(nixl_config, f)
+    with open(ucx_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(ucx_config, f)
+
+    env = llm_venv._new_env.copy()
+    nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        nixl_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        ucx_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    print(f"NIXL E2EL: {nixl_e2el} ms, UCX E2EL: {ucx_e2el} ms")
+    print(f"NIXL TTFT: {nixl_ttft} ms, UCX TTFT: {ucx_ttft} ms")
+
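+    # NIXL must be no more than 5% slower than UCX on either metric.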
+    assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
+    assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft