From 105f48312c46fb41e1d8ee1d8a8ae12e881af9fc Mon Sep 17 00:00:00 2001 From: Harish Subramony Date: Fri, 17 Oct 2025 19:47:14 +0000 Subject: [PATCH 1/2] initial port for heterogenous support Signed-off-by: Harish Subramony --- examples/nixl/run_accuracy_test.sh | 33 +- examples/nixl/run_benchmark_profile.sh | 20 +- examples/nixl/run_benchmark_test.sh | 118 +- examples/nixl/run_benchmark_test_heter.sh | 318 +++++ examples/nixl/run_tpu_disagg_accuracy_test.sh | 159 +++ examples/nixl/test_accuracy.py | 30 +- examples/nixl/test_edge_cases.py | 39 +- examples/nixl/toy_proxy_server.py | 216 ++-- vllm_gaudi/attention/backends/hpu_attn.py | 7 +- .../kv_connector/v1/hpu_nixl_connector.py | 1021 ++++++++++++++++- vllm_gaudi/v1/worker/hpu_model_runner.py | 78 +- 11 files changed, 1804 insertions(+), 235 deletions(-) create mode 100644 examples/nixl/run_benchmark_test_heter.sh create mode 100644 examples/nixl/run_tpu_disagg_accuracy_test.sh diff --git a/examples/nixl/run_accuracy_test.sh b/examples/nixl/run_accuracy_test.sh index 237452634..d00385722 100755 --- a/examples/nixl/run_accuracy_test.sh +++ b/examples/nixl/run_accuracy_test.sh @@ -1,27 +1,18 @@ #!/bin/bash -#set -xe +set -xe # Models to run -MODELS=( - "Qwen/Qwen3-0.6B" -) #MODELS=( -# "meta-llama/Llama-3.1-8B" +# "Qwen/Qwen3-0.6B" #) +MODELS=( + "meta-llama/Llama-3.1-8B-Instruct" +) export VLLM_USE_V1=1 export VLLM_SKIP_WARMUP="true" export PT_HPU_LAZY_MODE=1 -NIXL_BUFFER_DEVICE=${NIXL_BUFFER_DEVICE:-"cpu"} -VLLM_NIXL_BACKEND=${VLLM_NIXL_BACKEND:-"UCX"} - -if [ "$VLLM_NIXL_BACKEND" == "UCX" ]; then - export VLLM_NIXL_DEVICE_TO_DEVICE=false -else - export VLLM_NIXL_DEVICE_TO_DEVICE=true -fi - # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 @@ -96,9 +87,9 @@ run_tests_for_model() { GPU_ID=2 # Calculate port number (base port + instance number) - PORT=$((8700 + i)) + PORT=$((8300 + i)) # Calculate side channel port. Avoid clash with with TP workers. 
- SIDE_CHANNEL_PORT=$((6559 + i)) + SIDE_CHANNEL_PORT=$((5559 + i)) echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" @@ -109,7 +100,7 @@ run_tests_for_model() { --max_num_batched_tokens 8192 \ --gpu-memory-utilization 0.3 \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${NIXL_BUFFER_DEVICE}\", \"kv_connector_extra_config\":{\"backends\":[\"${VLLM_NIXL_BACKEND}\"]}}'" + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" @@ -129,7 +120,7 @@ run_tests_for_model() { # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs #GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus))) # Calculate port number (base port + instance number) - PORT=$((8800 + i)) + PORT=$((8400 + i)) # Calculate side channel port SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE)) @@ -142,7 +133,7 @@ run_tests_for_model() { --max_num_batched_tokens 8192 \ --gpu-memory-utilization 0.3 \ --tensor-parallel-size $DECODER_TP_SIZE \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${NIXL_BUFFER_DEVICE}\", \"kv_connector_extra_config\":{\"backends\":[\"${VLLM_NIXL_BACKEND}\"]}}'" + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" @@ -169,7 +160,7 @@ run_tests_for_model() { done # Build the command for the proxy server with all the hosts and ports - PROXY_CMD="python toy_proxy_server.py --port 9195" + PROXY_CMD="python toy_proxy_server.py --port 9192" # Add all prefill hosts and ports PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" @@ -184,7 +175,7 @@ run_tests_for_model() { $PROXY_CMD & # Wait for the proxy to start - sleep 20 + sleep 10 # curl -X POST -s http://localhost:9192/v1/completions \ # -H "Content-Type: application/json" \ diff --git a/examples/nixl/run_benchmark_profile.sh b/examples/nixl/run_benchmark_profile.sh index 9f8916ed9..af2d65ef6 100644 --- a/examples/nixl/run_benchmark_profile.sh +++ b/examples/nixl/run_benchmark_profile.sh @@ -20,7 +20,7 @@ export VLLM_USE_V1=1 export VLLM_SKIP_WARMUP=True export PT_HPU_LAZY_MODE=1 export HABANA_PROFILE=1 -Enable full vLLM Profiler and instruct where to save the profiling: +#Enable full vLLM Profiler and instruct where to save the profiling: export VLLM_PROFILER_ENABLED=1 export VLLM_TORCH_PROFILER_DIR=./ @@ -106,7 +106,7 @@ run_tests_for_model() { echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="RANK=0 UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="RANK=0 HABANA_VISIBLE_DEVICES=2 UCX_TLS=tcp VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ --port $PORT \ --long_prefill_token_threshold 8192 \ --max_num_batched_tokens 8192 \ @@ -135,12 +135,12 @@ run_tests_for_model() { # Calculate port number (base port + instance number) PORT=$((8400)) # Calculate side channel port - SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE)) + SIDE_CHANNEL_PORT=$((5559 + i * $DECODER_TP_SIZE)) echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="RANK=1 UCX_TLS=rc,ud,ib 
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="RANK=1 HABANA_VISIBLE_DEVICES=3 UCX_TLS=tcp VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ --port $PORT \ --gpu-memory-utilization 0.3 \ --tensor-parallel-size $DECODER_TP_SIZE \ @@ -189,7 +189,7 @@ run_tests_for_model() { $PROXY_CMD & # Wait for the proxy to start - sleep 500 + sleep 50 # curl -X POST -s http://localhost:9191/v1/completions \ # -H "Content-Type: application/json" \ @@ -234,14 +234,14 @@ run_tests_for_model() { --random-output-len 5 \ --num-prompts 10 \ --burstiness 100 \ - --request-rate 3.6 \ + --request-rate 0.1 \ --metric-percentiles 95 \ --percentile-metrics ttft,tpot,itl,e2el \ --backend openai \ --endpoint /v1/completions \ --ignore-eos - sleep 1000 + sleep 10 curl -X POST http://localhost:8300/start_profile curl -X POST http://localhost:8400/start_profile @@ -254,7 +254,7 @@ run_tests_for_model() { --random-output-len 5 \ --num-prompts 10 \ --burstiness 100 \ - --request-rate 3.6 \ + --request-rate 0.1 \ --metric-percentiles 95 \ --percentile-metrics ttft,tpot,itl,e2el \ --backend openai \ @@ -262,11 +262,11 @@ run_tests_for_model() { --ignore-eos - sleep 500 + sleep 10 curl -X POST http://localhost:8300/stop_profile curl -X POST http://localhost:8400/stop_profile - sleep 500 + sleep 10 # Clean up before running next model cleanup_instances sleep 3 diff --git a/examples/nixl/run_benchmark_test.sh b/examples/nixl/run_benchmark_test.sh index c9b5ba192..6ac657a65 100755 --- a/examples/nixl/run_benchmark_test.sh +++ b/examples/nixl/run_benchmark_test.sh @@ -13,14 +13,19 @@ set -xe MODELS=( "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/" ) + export VLLM_USE_V1=1 -#export VLLM_SKIP_WARMUP=True +export VLLM_SKIP_WARMUP=True export PT_HPU_LAZY_MODE=1 export VLLM_EXPONENTIAL_BUCKETING=False #export VLLM_PROMPT_BS_BUCKET_MIN=1 #export VLLM_PROMPT_SEQ_BUCKET_MIN=1 -#export VLLM_PROMPT_SEQ_BUCKET_STEP=8192 -#export VLLM_PROMPT_SEQ_BUCKET_MAX=8192 +export VLLM_PROMPT_SEQ_BUCKET_MIN=8192 +export VLLM_PROMPT_SEQ_BUCKET_STEP=8192 +export VLLM_PROMPT_SEQ_BUCKET_MAX=8192 +export VLLM_DECODE_BLOCK_BUCKET_MIN=1024 +export VLLM_DECODE_BLOCK_BUCKET_MAX=1184 +export VLLM_USE_PADDING_AWARE_SCHEDULING=1 # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 @@ -98,17 +103,17 @@ run_tests_for_model() { # Calculate port number (base port + instance number) PORT=$((8300 + i)) # Calculate side channel port. Avoid clash with with TP workers. 
- SIDE_CHANNEL_PORT=$((6559 + i)) + SIDE_CHANNEL_PORT=$((5559 + i)) echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="RANK=0 UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="HABANA_VISIBLE_DEVICES=0 RANK=0 UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ --port $PORT \ --long_prefill_token_threshold 8192 \ --max_num_batched_tokens 8192 \ - --gpu-memory-utilization 0.3 \ --disable-log-requests \ + --gpu-memory-utilization 0.3 \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" @@ -132,12 +137,12 @@ run_tests_for_model() { # Calculate port number (base port + instance number) PORT=$((8400 + i)) # Calculate side channel port - SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE)) + SIDE_CHANNEL_PORT=$((4659 + i * $DECODER_TP_SIZE)) echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="RANK=1 UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="HABANA_VISIBLE_DEVICES=1 RANK=1 UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ --port $PORT \ --gpu-memory-utilization 0.3 \ --tensor-parallel-size $DECODER_TP_SIZE \ @@ -171,7 +176,7 @@ run_tests_for_model() { done # Build the command for the proxy server with all the hosts and ports - PROXY_CMD="python toy_proxy_server.py --port 9192" + PROXY_CMD="python toy_proxy_server.py --port 9111" # Add all prefill hosts and ports PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" @@ -186,9 +191,9 @@ run_tests_for_model() { $PROXY_CMD & # Wait for the proxy to start - sleep 10 + sleep 100 -# curl -X POST -s http://localhost:9192/v1/completions \ +# curl -X POST -s http://localhost:9111/v1/completions \ # -H "Content-Type: application/json" \ # -d '{ # "model": "meta-llama/Llama-3.1-8B", @@ -198,7 +203,7 @@ run_tests_for_model() { # }' # sleep 5 # echo "--------------------===================-------------" -#curl -X POST -s http://localhost:9192/v1/completions \ +#curl -X POST -s http://localhost:9111/v1/completions \ # -H "Content-Type: application/json" \ # -d '{ # "model": "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", @@ -206,7 +211,7 @@ run_tests_for_model() { # "max_tokens": 5, # "temperature": 0 # }' - #curl -X POST -s http://localhost:9192/v1/completions \ + #curl -X POST -s http://localhost:9111/v1/completions \ # -H "Content-Type: application/json" \ # -d '{ # "model": "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", @@ -219,21 +224,78 @@ run_tests_for_model() { # Run lm eval for this model echo "Running tests for $model_name" #TEST_MODEL=$model_name python -m pytest -s -x test_accuracy.py - python3 ../../../../benchmarks/benchmark_serving.py \ - --port 9192 \ - --seed "$(date +%s)" \ - --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ - --dataset-name random \ - --random-input-len 8192 \ - --random-output-len 200 \ - --num-prompts 100 \ - --burstiness 100 \ - --request-rate 3.6 \ - --metric-percentiles 95 \ - --percentile-metrics 
ttft,tpot,itl,e2el \ - --backend openai \ - --endpoint /v1/completions \ - --ignore-eos + #python3 ../../../../benchmarks/benchmark_serving.py \ + # --port 9111 \ + # --seed "$(date +%s)" \ + # --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --dataset-name random \ + # --random-input-len 8192 \ + # --random-output-len 256 \ + # --num-prompts 32 \ + # --burstiness 100 \ + # --request-rate 3.6 \ + # --metric-percentiles 95 \ + # --percentile-metrics ttft,tpot,itl,e2el \ + # --backend openai \ + # --endpoint /v1/completions \ + # --ignore-eos + + #sleep 100 + #python3 ../../../../benchmarks/benchmark_serving.py \ + # --port 8300 \ + # --seed "$(date +%s)" \ + # --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --dataset-name random \ + # --random-input-len 8192 \ + # --random-output-len 200 \ + # --num-prompts 100 \ + # --burstiness 100 \ + # --request-rate 3.6 \ + # --metric-percentiles 95 \ + # --percentile-metrics ttft,tpot,itl,e2el \ + # --backend openai \ + # --endpoint /v1/completions \ + # --ignore-eos + qps=(0.5) #(0.1 0.25 0.5 1 2 3 4) # 5) + # explicit num_prompts mapping (must have same length as qps[]) + num_prompts=(32) #(32 64 128 256 256 256 256) # 256) + input_len=8192 + output_len=256 #56 + + # just sanity‐check lengths + if [ "${#qps[@]}" -ne "${#num_prompts[@]}" ]; then + echo "❌ qps[] and num_prompts[] must be the same length" + exit 1 + fi + + for i in "${!qps[@]}"; do + q=${qps[$i]} + np=${num_prompts[$i]} + + ts=$(date +"%Y%m%d_%H%M%S") + logf="./nixlresult/run_in${input_len}_out${output_len}_qps${q//./p}_$ts.log" + + echo "[$(date +"%Y-%m-%d %H:%M:%S")] input=${input_len}, output=${output_len}, qps=${q}, num_prompts=${np}" \ + | tee "$logf" + + python3 ../../../../benchmarks/benchmark_serving.py \ + --port 9111 \ + --seed "$(date +%s)" \ + --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + --tokenizer /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + --dataset-name random \ + --random-input-len "$input_len" \ + --random-output-len 256 \ + --num-prompts "$np" \ + --request-rate "$q" \ + --percentile-metrics ttft,tpot,itl,e2el \ + --burstiness 100 \ + --backend openai \ + --endpoint /v1/completions \ + --ignore-eos \ + 2>&1 | tee -a "$logf" + + done # Clean up before running next model cleanup_instances diff --git a/examples/nixl/run_benchmark_test_heter.sh b/examples/nixl/run_benchmark_test_heter.sh new file mode 100644 index 000000000..83a10fbbd --- /dev/null +++ b/examples/nixl/run_benchmark_test_heter.sh @@ -0,0 +1,318 @@ +#!/bin/bash +set -xe + +# Models to run +#MODELS=( +# "Qwen/Qwen3-0.6B" +#) +#MODELS=( +# "meta-llama/Llama-3.1-8B" +#) + + +MODELS=( + "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/" +) +#MODELS=( +# "Qwen/Qwen3-0.6B" +#) +export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=1000000 +export VLLM_RPC_TIMEOUT=1000000000 +export NIXL_LOG_LEVEL=debug +#export UCX_LOG_LEVEL=debug +export VLLM_USE_V1=1 +export VLLM_SKIP_WARMUP=True +export PT_HPU_LAZY_MODE=1 +export VLLM_EXPONENTIAL_BUCKETING=False +export VLLM_PROMPT_BS_BUCKET_MIN=1 +export VLLM_PROMPT_SEQ_BUCKET_MIN=1 +export 
VLLM_PROMPT_SEQ_BUCKET_MIN=8192 +export VLLM_PROMPT_SEQ_BUCKET_STEP=8192 +export VLLM_PROMPT_SEQ_BUCKET_MAX=8192 +export VLLM_DECODE_BLOCK_BUCKET_MIN=1024 +export VLLM_DECODE_BLOCK_BUCKET_MAX=1184 +export VLLM_USE_PADDING_AWARE_SCHEDULING=1 +export DECODER_TP_RATIO=2 + +# Number of prefill and decode instances to create +NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 +NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 +PREFILLER_TP_SIZE=2 #${PREFILLER_TP_SIZE:-1} +DECODER_TP_SIZE=4 #${DECODER_TP_SIZE:-1} + + +# Find the git repository root directory +#GIT_ROOT=$(git rev-parse --show-toplevel) +GIT_ROOT="/home/vllm-nixl/vllm" + +#SMI_BIN=$(which nvidia-smi || which rocm-smi) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + +get_num_gpus() { + if [[ "$SMI_BIN" == *"nvidia"* ]]; then + echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)" + else + echo "$($SMI_BIN -l | grep GPU | wc -l)" + fi +} + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Arrays to store all hosts and ports + PREFILL_HOSTS=() + PREFILL_PORTS=() + DECODE_HOSTS=() + DECODE_PORTS=() + + # Start prefill instances + for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs + #GPU_ID=$((i % $(get_num_gpus))) + GPU_ID=2 + + # Calculate port number (base port + instance number) + PORT=$((8300 + i)) + # Calculate side channel port. Avoid clash with with TP workers. 
+ SIDE_CHANNEL_PORT=$((5559 + i)) + + echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="MY_ROLE=PREFILL UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --long_prefill_token_threshold 8192 \ + --max_num_batched_tokens 8192 \ + --disable-log-requests \ + --gpu-memory-utilization 0.3 \ + --tensor-parallel-size $PREFILLER_TP_SIZE \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + PREFILL_HOSTS+=("localhost") + PREFILL_PORTS+=($PORT) + done + + # Start decode instances + for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs + #GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus))) + # Calculate port number (base port + instance number) + PORT=$((8400 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((4659 + i * $DECODER_TP_SIZE)) + + echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="MY_ROLE=DECODE UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --gpu-memory-utilization 0.3 \ + --tensor-parallel-size $DECODER_TP_SIZE \ + --long_prefill_token_threshold 8192 \ + --max_num_batched_tokens 8192 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + DECODE_HOSTS+=("localhost") + DECODE_PORTS+=($PORT) + done + + # Wait for all instances to start + for PORT in "${PREFILL_PORTS[@]}"; do + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PORT + done + + for PORT in "${DECODE_PORTS[@]}"; do + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $PORT + done + + # Build the command for the proxy server with all the hosts and ports + PROXY_CMD="python toy_proxy_server.py --port 9111" + + # Add all prefill hosts and ports + PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" + + # Add all decode hosts and ports + PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" + + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 10 + +# curl -X POST -s http://localhost:9111/v1/completions \ +# -H "Content-Type: application/json" \ +# -d '{ +# "model": "meta-llama/Llama-3.1-8B", +# "prompt": "Mark Elliot Zuckerberg is an American businessman who co-founded the social media service Facebook and its parent company Meta Platforms, of which he is the chairman, chief executive officer, and controlling shareholder. Zuckerberg has been the subject of multiple lawsuits regarding the creation and ownership of the website as well as issues such as user privacy. 
Born in White Plains, New York, Zuckerberg briefly attended Harvard College, where he launched Facebook in February 2004 with his roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz and Chris Hughes. Zuckerberg took the company public in May 2012 with majority shares. He became the worlds youngest self-made billionaire[a] in 2008, at age 23, and has consistently ranked among the worlds wealthiest individuals. According to Forbes, Zuckerbergs estimated net worth stood at US$221.2 billion as of May 2025, making him the second-richest individual in the world.[2]", +# "max_tokens": 5, +# "temperature": 0 +# }' +# sleep 5 +# echo "--------------------===================-------------" +curl -X POST -s http://localhost:9111/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", + "prompt": "Mark Elliot Zuckerberg is an American businessman who co-founded the social media service Facebook and its parent company Meta Platforms, of which he is the chairman, chief executive officer, and controlling shareholder. Zuckerberg has been the subject of multiple lawsuits regarding the creation and ownership of the website as well as issues such as user privacy. Born in White Plains, New York, Zuckerberg briefly attended Harvard College, where he launched Facebook in February 2004 with his roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz and Chris Hughes. Zuckerberg took the company public in May 2012 with majority shares. He became the worlds youngest self-made billionaire[a] in 2008, at age 23, and has consistently ranked among the worlds wealthiest individuals. According to Forbes, Zuckerbergs estimated net worth stood at US$221.2 billion as of May 2025, making him the second-richest individual in the world.[2] Intel opened its first international manufacturing facility in 1972, in Malaysia, which would host multiple Intel operations, before opening assembly facilities and semiconductor plants in Singapore and Jerusalem in the early 1980s, and manufacturing and development centers in China, India, and Costa Rica in the 1990s.[31] By the early 1980s, its business was dominated by DRAM chips. However, increased competition from Japanese semiconductor manufacturers had, by 1983, dramatically reduced the profitability of this market. The growing success of the IBM personal computer, based on an Intel microprocessor, was among factors that convinced Gordon Moore (CEO since 1975) to shift the companys focus to microprocessors and to change fundamental aspects of that business model. Moores decision to sole-source Intels 386 chip played into the companys continuing success.", + "max_tokens": 50, + "temperature": 0 + }' +#curl -X POST -s http://localhost:9111/v1/completions \ +# -H "Content-Type: application/json" \ +# -d '{ +# "model": "/root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", +# "prompt": ["This was a few months ago. It was my day off and the only thing I had to do was pick my girlfriend up from work at 9:00 pm. Other than that, I was free to loaf on the couch from morning to night, which is what I did. Around 8:00, I decided to shower before I left the house. 
Now, I have short hair that dries pretty quickly, but I am deeply vain about it, so I always dry it with the hairdryer right after I shower to ensure my hair doesnt get flat and weird. I never skip this step. So, I get out of the shower, start drying my hair... And then I wake up in bed. Its half an hour later. I feel like garbage, my entire body mysteriously hurts, and I am slowly realizing that I dont remember exiting the bathroom. My only clear thought is: oh shit, its 9:00! I have to pick up my girlfriend! Better shake myself awake. I dragged my aching carcass back to the bathroom, and this was when I noticed the massive blisters forming all over my hand. I was still pretty out of it, but I knew that this was a hospital visit kind of burn. My girlfriend then called to check in because I was running late and, despite my undoubtedly convincing argument that I was still perfectly fine to drive, she immediately knew something was wrong. She cabbed home and we got a ride to the ER. Turns out, I had my first ever seizure! It seems like during the seizure, I clenched the hairdryer in my fist and had it pointed at my other hand long enough to thoroughly cook it. The tissue loss is pretty deep in some areas and there was concerns about me retaining my mobility, but its been healing well so far.", +# "Mark Elliot Zuckerberg is an American businessman who co-founded the social media service Facebook and its parent company Meta Platforms, of which he is the chairman, chief executive officer, and controlling shareholder. Zuckerberg has been the subject of multiple lawsuits regarding the creation and ownership of the website as well as issues such as user privacy. Born in White Plains, New York, Zuckerberg briefly attended Harvard College, where he launched Facebook in February 2004 with his roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz and Chris Hughes. Zuckerberg took the company public in May 2012 with majority shares. He became the worlds youngest self-made billionaire[a] in 2008, at age 23, and has consistently ranked among the worlds wealthiest individuals. 
According to Forbes, Zuckerbergs estimated net worth stood at US$221.2 billion as of May 2025, making him the second-richest individual in the world.[2]"], +# "max_tokens": 100, +# "temperature": 0 +# }' + #sleep 2 + # Run lm eval for this model + #echo "Running tests for $model_name" + #TEST_MODEL=$model_name python -m pytest -s -x test_accuracy.py + #python3 ../../../../benchmarks/benchmark_serving.py \ + # --port 9111 \ + # --seed "$(date +%s)" \ + # --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --dataset-name random \ + # --random-input-len 8192 \ + # --random-output-len 256 \ + # --num-prompts 32 \ + # --burstiness 100 \ + # --request-rate 3.6 \ + # --metric-percentiles 95 \ + # --percentile-metrics ttft,tpot,itl,e2el \ + # --backend openai \ + # --endpoint /v1/completions \ + # --ignore-eos + + #sleep 100 + #python3 ../../../../benchmarks/benchmark_serving.py \ + # --port 8300 \ + # --seed "$(date +%s)" \ + # --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --dataset-name random \ + # --random-input-len 8192 \ + # --random-output-len 200 \ + # --num-prompts 100 \ + # --burstiness 100 \ + # --request-rate 3.6 \ + # --metric-percentiles 95 \ + # --percentile-metrics ttft,tpot,itl,e2el \ + # --backend openai \ + # --endpoint /v1/completions \ + # --ignore-eos + qps=(0.5) #(0.1 0.25 0.5 1 2 3 4) # 5) + # explicit num_prompts mapping (must have same length as qps[]) + num_prompts=(32) #(32 64 128 256 256 256 256) # 256) + input_len=8192 + output_len=256 #56 + + # just sanity‐check lengths + #if [ "${#qps[@]}" -ne "${#num_prompts[@]}" ]; then + # echo "❌ qps[] and num_prompts[] must be the same length" + # exit 1 + #fi + + #for i in "${!qps[@]}"; do + #q=${qps[$i]} + #np=${num_prompts[$i]} + + #ts=$(date +"%Y%m%d_%H%M%S") + #logf="./nixlresult/run_in${input_len}_out${output_len}_qps${q//./p}_$ts.log" + + #echo "[$(date +"%Y-%m-%d %H:%M:%S")] input=${input_len}, output=${output_len}, qps=${q}, num_prompts=${np}" \ + # | tee "$logf" + + #python3 ../../../../benchmarks/benchmark_serving.py \ + # --port 9111 \ + # --seed "$(date +%s)" \ + # --model /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --tokenizer /root/software/data/pytorch/huggingface/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/ \ + # --dataset-name random \ + # --random-input-len "$input_len" \ + # --random-output-len 256 \ + # --num-prompts "$np" \ + # --request-rate "$q" \ + # --percentile-metrics ttft,tpot,itl,e2el \ + # --burstiness 100 \ + # --backend openai \ + # --endpoint /v1/completions \ + # --ignore-eos \ + # 2>&1 | tee -a "$logf" + + #done + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" 
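
Editor's note: the new run_benchmark_test_heter.sh above is mostly scaffolding plus commented-out benchmark runs, so the essential heterogeneous prefill/decode launch it performs is easy to miss. The sketch below distills it: one TP=2 prefill instance, one TP=4 decode instance (hence DECODER_TP_RATIO=2, which the connector code later in this patch reads), and the toy proxy in front. This is a minimal, hedged reconstruction of what the script does, not an addition to it; the model path is a placeholder, the port numbers simply mirror the values used in the script, and the bucketing/warmup exports from the script are omitted for brevity.

# Minimal heterogeneous P/D launch, distilled from run_benchmark_test_heter.sh above.
MODEL=/path/to/Llama-3.1-8B-Instruct   # placeholder; the script points at a local HF snapshot
export PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=True DECODER_TP_RATIO=2

# Prefill worker (TP=2); NIXL side channel on 5559, KV staged through a CPU buffer.
MY_ROLE=PREFILL UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
  vllm serve "$MODEL" --port 8300 --tensor-parallel-size 2 \
  --long_prefill_token_threshold 8192 --max_num_batched_tokens 8192 \
  --gpu-memory-utilization 0.3 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}' &

# Decode worker (TP=4); separate side-channel base port so its TP ranks do not clash.
MY_ROLE=DECODE UCX_TLS=rc,ud,ib VLLM_NIXL_SIDE_CHANNEL_PORT=4659 \
  vllm serve "$MODEL" --port 8400 --tensor-parallel-size 4 \
  --long_prefill_token_threshold 8192 --max_num_batched_tokens 8192 \
  --gpu-memory-utilization 0.3 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}' &

# Round-robin proxy: sends each request to a prefill instance first, then streams
# the continuation from a decode instance (see toy_proxy_server.py in this patch).
python toy_proxy_server.py --port 9111 \
  --prefiller-hosts localhost --prefiller-ports 8300 \
  --decoder-hosts localhost --decoder-ports 8400 &

As in the script, each serve command would normally be followed by a wait_for_server loop before the proxy is started and the curl/benchmark requests are issued; the side-channel base port is spaced per instance because each TP rank listens on base_port + tp_rank.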
diff --git a/examples/nixl/run_tpu_disagg_accuracy_test.sh b/examples/nixl/run_tpu_disagg_accuracy_test.sh new file mode 100644 index 000000000..ea125f99f --- /dev/null +++ b/examples/nixl/run_tpu_disagg_accuracy_test.sh @@ -0,0 +1,159 @@ +#!/bin/bash +set -xe + +# Hosts / ports +PREFILL_HOST=${PREFILL_HOST:-"localhost"} +PREFILL_PORT=${PREFILL_PORT:-8100} +PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} +DECODE_HOST=${DECODE_HOST:-"localhost"} +DECODE_PORT=${DECODE_PORT:-8200} +PROXY_HOST=${PROXY_HOST:-"localhost"} +PROXY_PORT=${PROXY_PORT:-8192} +BASELINE_HOST=${BASELINE_HOST:-"localhost"} +BASELINE_PORT=${BASELINE_PORT:-9290} + + +# Model to run. +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} +BLOCK_SIZE=${BLOCK_SIZE:-32} + + +# execution env +GIT_ROOT=$(git rev-parse --show-toplevel) +EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" +CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"} +CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"} + +OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + + +# Waits for vLLM server to start. +wait_for_server() { + local host=$1 + local port=$2 + timeout 1200 bash -c " + until curl -s ${host}:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 || true + # pkill -f python || true + echo "Cleanup complete. Exiting." +} + +launch_baseline() { + BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${BASELINE_HOST} \ + --port ${BASELINE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --gpu-memory-utilization 0.5 \ + --enforce-eager" + echo ${BASELINE_BASE_CMD} + ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" & +} + +launch_pd() { + PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ + VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${PREFILL_HOST} \ + --port ${PREFILL_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + + DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${DECODE_HOST} \ + --port ${DECODE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + echo ${PREFILL_BASE_CMD} + echo ${DECODE_BASE_CMD} + sleep 2 + + 
# execute on hosts + ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" & + ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" & + sleep 1 + wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} + sleep 1 + wait_for_server ${DECODE_HOST} ${DECODE_PORT} + sleep 1 +} + +launch_pd_proxy(){ + PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + python3 ${EXP_ROOT}/toy_proxy_server.py \ + --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ + --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ + --host=${PROXY_HOST} --port ${PROXY_PORT}" + echo ${PROXY_BASE_CMD} + ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" & +} + +run_tests(){ + local service_url=$1 + local mode=$2 + python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE} +} + + +# run non-disagg. baseline & save outputs +launch_baseline +sleep 2 +wait_for_server ${BASELINE_HOST} ${BASELINE_PORT} +run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline" +cleanup +sleep 10 + + +# run disagg. & do exact-match with the outputs from baseline +launch_pd +launch_pd_proxy +sleep 10 +run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg" +echo "-----P/D success----" + +rm ${OUTPUT_FILE} +cleanup + +exit 0 \ No newline at end of file diff --git a/examples/nixl/test_accuracy.py b/examples/nixl/test_accuracy.py index 8c179620b..f3381a31d 100644 --- a/examples/nixl/test_accuracy.py +++ b/examples/nixl/test_accuracy.py @@ -5,7 +5,7 @@ import lm_eval import openai -BASE_URL = "http://localhost:9195/v1" +BASE_URL = "http://localhost:9192/v1" NUM_CONCURRENT = 100 TASK = "gsm8k" FILTER = "exact_match,strict-match" @@ -14,10 +14,10 @@ # Model-specific expected values EXPECTED_VALUES = { "Qwen/Qwen3-0.6B": 0.41, - "deepseek-ai/deepseek-vl2-small": 0.59, + "deepseek-ai/deepseek-vl2-small": 0.59 } -SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means" # noqa: E501 +SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 # Get model name from environment variable MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") @@ -25,9 +25,8 @@ def run_simple_prompt(): client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) - completion = client.completions.create( - model=MODEL_NAME, prompt=SIMPLE_PROMPT - ) # yapf: disable + completion = client.completions.create(model=MODEL_NAME, + prompt=SIMPLE_PROMPT) print("-" * 50) print(f"Completion results for {MODEL_NAME}:") @@ -39,28 +38,25 @@ def test_accuracy(): """Run the end to end accuracy test.""" run_simple_prompt() - model_args = ( - f"model={MODEL_NAME}," - f"base_url={BASE_URL}/completions," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" - ) # yapf: disable + model_args = (f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") results = lm_eval.simple_evaluate( model="local-completions", model_args=model_args, tasks=TASK, - limit=128, ) measured_value = results["results"][TASK][FILTER] expected_value = EXPECTED_VALUES.get(MODEL_NAME) if expected_value is None: - print( - f"Warning: No expected value found for {MODEL_NAME}. " - "Skipping accuracy check." - ) # yapf: disable + print(f"Warning: No expected value found for {MODEL_NAME}. 
" + "Skipping accuracy check.") print(f"Measured value: {measured_value}") return - assert measured_value + RTOL > expected_value, f"Expected: {expected_value} | Measured: {measured_value}" + assert (measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/examples/nixl/test_edge_cases.py b/examples/nixl/test_edge_cases.py index 520dd665b..95465a25f 100644 --- a/examples/nixl/test_edge_cases.py +++ b/examples/nixl/test_edge_cases.py @@ -10,8 +10,7 @@ if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: raise ValueError( - "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT." - ) # yapf: disable + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 @@ -39,13 +38,13 @@ def test_edge_cases(): # (1) Check that we can handle a very short prompt, # less than the length of the block size. - completion = proxy_client.completions.create( - model=MODEL, prompt=SHORT_PROMPT, temperature=0 - ) # yapf: disable + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create( - model=MODEL, prompt=SHORT_PROMPT, temperature=0 - ) # yapf: disable + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) prefill_response = completion.choices[0].text print(f"SMALL PROMPT: {proxy_response=}") assert proxy_response == prefill_response @@ -53,27 +52,27 @@ def test_edge_cases(): # (2) Check that we can handle a full prefix cache # hit on the D worker but not on the P worker. # (2a): prime the D worker. - completion = decode_client.completions.create( - model=MODEL, prompt=PROMPT, temperature=0 - ) # yapf: disable + completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) decode_response = completion.choices[0].text # (2b): send via the P/D setup - completion = proxy_client.completions.create( - model=MODEL, prompt=PROMPT, temperature=0 - ) # yapf: disable + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) proxy_response = completion.choices[0].text print(f"FULL CACHE HIT: {proxy_response=}") assert proxy_response == decode_response # (3) Check that we can handle a partial prefix cache # hit on the D worker. 
- completion = proxy_client.completions.create( - model=MODEL, prompt=LONG_PROMPT, temperature=0 - ) # yapf: disable + completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create( - model=MODEL, prompt=LONG_PROMPT, temperature=0 - ) # yapf: disable + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) prefill_response = completion.choices[0].text print(f"PARTIAL CACHE HIT: {proxy_response=}") assert proxy_response == prefill_response diff --git a/examples/nixl/toy_proxy_server.py b/examples/nixl/toy_proxy_server.py index c0a26e485..cb2a7bc4c 100644 --- a/examples/nixl/toy_proxy_server.py +++ b/examples/nixl/toy_proxy_server.py @@ -4,7 +4,7 @@ import argparse import itertools import logging -import os +import os,time,sys import uuid from contextlib import asynccontextmanager @@ -27,53 +27,49 @@ async def lifespan(app: FastAPI): # Create prefill clients for i, (host, port) in enumerate(global_args.prefiller_instances): - prefiller_base_url = f"http://{host}:{port}/v1" - app.state.prefill_clients.append( - { - "client": httpx.AsyncClient( - timeout=None, base_url=prefiller_base_url - ), - "host": host, - "port": port, - "id": i, - } - ) # yapf: disable + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) # Create decode clients for i, (host, port) in enumerate(global_args.decoder_instances): - decoder_base_url = f"http://{host}:{port}/v1" - app.state.decode_clients.append( - { - "client": httpx.AsyncClient( - timeout=None, base_url=decoder_base_url - ), - "host": host, - "port": port, - "id": i, - } - ) # yapf: disable + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) # Initialize round-robin iterators app.state.prefill_iterator = itertools.cycle( - range(len(app.state.prefill_clients)) - ) # yapf: disable + range(len(app.state.prefill_clients))) app.state.decode_iterator = itertools.cycle( - range(len(app.state.decode_clients)) - ) # yapf: disable + range(len(app.state.decode_clients))) - print( - f"Initialized {len(app.state.prefill_clients)} prefill clients " - f"and {len(app.state.decode_clients)} decode clients." 
- ) # yapf: disable + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") yield # Shutdown: Close all clients for client_info in app.state.prefill_clients: - await client_info["client"].aclose() + await client_info['client'].aclose() for client_info in app.state.decode_clients: - await client_info["client"].aclose() + await client_info['client'].aclose() # Update FastAPI app initialization to use lifespan @@ -87,54 +83,43 @@ def parse_args(): parser.add_argument("--host", type=str, default="localhost") # For prefiller instances - parser.add_argument( - "--prefiller-hosts", - "--prefiller-host", - type=str, - nargs="+", - default=["localhost"], - ) - parser.add_argument( - "--prefiller-ports", - "--prefiller-port", - type=int, - nargs="+", - default=[8100], - ) + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) # For decoder instances - parser.add_argument( - "--decoder-hosts", - "--decoder-host", - type=str, - nargs="+", - default=["localhost"], - ) - parser.add_argument( - "--decoder-ports", - "--decoder-port", - type=int, - nargs="+", - default=[8200], - ) + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) args = parser.parse_args() # Validate and pair hosts with ports if len(args.prefiller_hosts) != len(args.prefiller_ports): raise ValueError( - "Number of prefiller hosts must match number of prefiller ports" - ) # yapf: disable + "Number of prefiller hosts must match number of prefiller ports") if len(args.decoder_hosts) != len(args.decoder_ports): raise ValueError( - "Number of decoder hosts must match number of decoder ports" - ) # yapf: disable + "Number of decoder hosts must match number of decoder ports") # Create tuples of (host, port) for each service type args.prefiller_instances = list( - zip(args.prefiller_hosts, args.prefiller_ports) - ) # yapf: disable + zip(args.prefiller_hosts, args.prefiller_ports)) args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) return args @@ -151,30 +136,29 @@ def get_next_client(app, service_type: str): Returns: The next client to use """ - if service_type == "prefill": + if service_type == 'prefill': client_idx = next(app.state.prefill_iterator) return app.state.prefill_clients[client_idx] - elif service_type == "decode": + elif service_type == 'decode': client_idx = next(app.state.decode_iterator) return app.state.decode_clients[client_idx] else: raise ValueError(f"Unknown service type: {service_type}") -async def send_request_to_service( - client_info: dict, endpoint: str, req_data: dict, request_id: str -): # yapf: disable +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): """ Send a request to a service using a client from the pool. 
""" req_data = req_data.copy() - req_data["kv_transfer_params"] = { + req_data['kv_transfer_params'] = { "do_remote_decode": True, "do_remote_prefill": False, "remote_engine_id": None, "remote_block_ids": None, "remote_host": None, - "remote_port": None, + "remote_port": None } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -184,79 +168,98 @@ async def send_request_to_service( del req_data["stream_options"] headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id, + "X-Request-Id": request_id } - response = await client_info["client"].post( - endpoint, json=req_data, headers=headers - ) # yapf: disable + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) response.raise_for_status() return response -async def stream_service_response( - client_info: dict, endpoint: str, req_data: dict, request_id: str -): # yapf: disable +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): """ Asynchronously stream response from a service using a client from the pool. """ + s1 = time.perf_counter() headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id, + "X-Request-Id": request_id } - async with client_info["client"].stream( - "POST", endpoint, json=req_data, headers=headers - ) as response: # yapf: disable + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk async def _handle_completions(api: str, request: Request): + s1 = time.perf_counter() try: req_data = await request.json() request_id = str(uuid.uuid4()) # Get the next prefill client in round-robin fashion - prefill_client_info = get_next_client(request.app, "prefill") + prefill_client_info = get_next_client(request.app, 'prefill') # Send request to prefill service - response = await send_request_to_service( - prefill_client_info, api, req_data, request_id - ) # yapf: disable - + p = send_request_to_service(prefill_client_info, api, + req_data, request_id) + s2 = time.perf_counter() + sys.stdout.flush() + response = await p + s3 = time.perf_counter() # Extract the needed fields response_json = response.json() - kv_transfer_params = response_json.get("kv_transfer_params", {}) + kv_transfer_params = response_json.get('kv_transfer_params', {}) if kv_transfer_params: + #remote_block_len = len(kv_transfer_params['remote_block_ids']) + #logger.debug('buke: cut:', type(kv_transfer_params), kv_transfer_params['remote_block_ids'],kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8]) + + #kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8] + #if remote_block_len % 8 == 0: + # kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8-1] + # logger.info('buke hit corner case multiples of 8:', remote_block_len) req_data["kv_transfer_params"] = kv_transfer_params - + #print(req_data) # Get the next decode client in round-robin fashion - decode_client_info = get_next_client(request.app, "decode") + decode_client_info = get_next_client(request.app, 'decode') logger.debug("Using %s %s", prefill_client_info, decode_client_info) # Stream response from decode service async def generate_stream(): - async for chunk in stream_service_response( - decode_client_info, api, req_data, request_id=request_id - ): # yapf: disable + 
is_first = False + s6 = time.perf_counter() + async for chunk in stream_service_response(decode_client_info, + api, + req_data, + request_id=request_id): + + if is_first is False: + s4 = time.perf_counter() + sys.stdout.flush() + is_first = True yield chunk - return StreamingResponse( - generate_stream(), media_type="application/json" - ) # yapf: disable + re = StreamingResponse(generate_stream(), + media_type="application/json") + s5 = time.perf_counter() + + #sys.stdout.flush() + return re except Exception as e: - import sys import traceback - exc_info = sys.exc_info() - print( - f"Error occurred in disagg prefill proxy server - {api} endpoint" - ) # yapf: disable + print("Error occurred in disagg prefill proxy server" + f" - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @@ -278,14 +281,13 @@ async def healthcheck(): return { "status": "ok", "prefill_instances": len(app.state.prefill_clients), - "decode_instances": len(app.state.decode_clients), + "decode_instances": len(app.state.decode_clients) } -if __name__ == "__main__": +if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 9b541a4ef..a5ccd01de 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -446,7 +446,7 @@ def __init__( if head_size not in supported_head_sizes: raise ValueError(f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") - + self.is_prompt = True self.attn_type = attn_type if (self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_DECODER and self.attn_type != AttentionType.ENCODER_ONLY): @@ -539,6 +539,7 @@ def forward( if attn_metadata.is_prompt: # Prompt run. 
+ self.is_prompt = True query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) @@ -630,8 +631,8 @@ def common_attention_args(self, block_list=None, key_cache=None, value_cache=Non 'batch2block_matmul_op': self.batch2block_matmul, 'block2batch_matmul_op': self.block2batch_matmul, 'fsdpa_op': self.fused_scaled_dot_product_attention, - 'keys_fetch_func': self.k_cache.fetch_from_cache, - 'values_fetch_func': self.v_cache.fetch_from_cache, + 'keys_fetch_func': self.k_cache.fetch_from_cache if (not self.is_prompt or not self.use_contiguous_pa) else self.k_cache.fetch_from_cache_prompt, + 'values_fetch_func': self.v_cache.fetch_from_cache if (not self.is_prompt or not self.use_contiguous_pa) else self.v_cache.fetch_from_cache_prompt, 'softmax_op': self.softmax, 'block_list': block_list, 'key_cache': key_cache, diff --git a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py index aecfde995..aa8929ee1 100644 --- a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py +++ b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py @@ -1,9 +1,1021 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +import math +import uuid +import queue +import threading +import time import torch -from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (NixlConnectorWorker) +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from vllm.distributed.utils import divide +from vllm import envs +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (NixlConnector, NixlConnectorWorker, NixlKVConnectorStats, + NixlConnectorMetadata, NixlConnectorScheduler, NixlAgentMetadata) from vllm_gaudi.platform import logger -import habana_frameworks.torch.utils.experimental as htexp +from vllm.platforms import current_platform +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from vllm.v1.attention.backends.utils import get_kv_cache_layout +Transfer = tuple[int, float] # (xfer_handle, start_time) +EngineId = str +ReqId = str + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._bindings import nixlXferTelemetry + import habana_frameworks.torch.utils as htutils + logger.info("htutils is available") +except ImportError: + logger.warning("htutils is not available") + htutils = None + logger.warning("NIXL is not available") + NixlWrapper = None + nixlXferTelemetry = None + +try: + from nixl._api import nixl_agent_config +except ImportError: + nixl_agent_config = None + logger.warning("NIXL agent config is not available") + +# Supported platforms and types of kv transfer buffer. 
+# {device: tuple of supported kv buffer types} +_NIXL_SUPPORTED_DEVICE = { + "cuda": ( + "cuda", + "cpu", + ), + "tpu": ("cpu",), + "xpu": ("cpu",), +} +# support for oot platform by providing mapping in current_platform +_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + tp_size: int + # Whether this request had a full/partial in-memory (local) hit so + # that only the remainining blocks are required to read. + # This is a wicked fix for heterogeneous devices test between + # Nvidia device and Habana device since the block ids are not aligned. + # We should ideally set kv_transfer_params["is_mem_hit"] to True + # by scheduler/worker logic once a memory hit condition is detected. + # TODO: remove this field once vllm-fork rebases vllm upstream repo + is_mem_hit: bool = False + + +def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + load_remote_cache: bool = True, + save_to_host: bool = False, +): + # save and load are mutually exclusive + assert load_remote_cache ^ save_to_host + _req = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + # P workers don't need to receive tp_size from proxy here. + tp_size=kv_transfer_params.get("tp_size", 1), + is_mem_hit=kv_transfer_params.get("is_mem_hit", False), + ) + if save_to_host: + self.reqs_to_save[request_id] = _req + if load_remote_cache: + self.reqs_to_recv[request_id] = _req + +NixlConnectorMetadata.add_new_req = add_new_req + +def NixlConnectorScheduler__init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id: EngineId = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + assert vllm_config.kv_transfer_config is not None + self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" + logger.info("Initializing NIXL Scheduler %s", engine_id) + self.hetero_blk_id_wa = os.getenv('PT_HPU_HETERO_BLOCK_ID_WA', '1') == '1' + + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} + # Reqs to send and their expiration time + self._reqs_need_send: dict[ReqId, float] = {} + self._reqs_in_batch: set[ReqId] = set() + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. 
+ self._reqs_not_processed: set[ReqId] = set() + +NixlConnectorScheduler.__init__ = NixlConnectorScheduler__init__ + +def wait_for_save(self): + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.rewrite_kv_based_on_transfer_layout(self._connector_metadata) + if self.connector_worker.use_host_buffer and \ + self.connector_worker.copy_blocks: + self.connector_worker.save_kv_to_host(self._connector_metadata) + +NixlConnector.wait_for_save = wait_for_save + +NixlConnectorScheduler.hetero_blk_id_wa = os.getenv('PT_HPU_HETERO_BLOCK_ID_WA', '1') == '1' + +def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + logger.debug(f'buke update_state_after_alloc: {vars(request)=}') + if not params: + return + if params.get("do_remote_decode"): + self._reqs_in_batch.add(request.request_id) + if self.use_host_buffer and params.get("do_remote_decode"): + # NOTE: when accelerator is not directly supported by Nixl, + # prefilled blocks need to be saved to host memory before transfer. + + # figure out full computed blocks to save + block_ids = blocks.get_block_ids()[0] + all_full = request.num_tokens % self.block_size == 0 + full_block_ids = (block_ids if all_full else block_ids[:-1]) + # TODO: skip the blocks that are already in the host xfer buffer. + # Currently, the host xfer buffer block is 1-to-1 mapped to device + # kv blocks, so host blocks won't be flushed as long as its device + # block is not overwritten; and it will be safe to skip saving them + # to host xfer buffer. + if full_block_ids: + self._reqs_need_save[request.request_id] = \ + (request, full_block_ids) + elif params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + if self.hetero_blk_id_wa: + block_ids = blocks.get_block_ids()[0] + local_block_ids = blocks.get_unhashed_block_ids() + if num_external_tokens > 0: + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + if len(block_ids) > len(local_block_ids): + params["is_mem_hit"] = True + logger.debug(f"jwang {request.request_id=} {block_ids=} {local_block_ids=} need _reqs_need_recv ") + else: + #self._reqs_need_recv[request.request_id] = (request, []) + assert len(block_ids) >= len(local_block_ids), \ + f"jwang oops, it really happens {request.request_id=} {block_ids=} {local_block_ids=}" + else: + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. 
+ params["do_remote_prefill"] = False + +NixlConnectorScheduler.update_state_after_alloc = update_state_after_alloc + +def NixlConnectorWorker__init__(self, vllm_config: VllmConfig, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + self.decoder_tp_ratio = int(os.getenv('DECODER_TP_RATIO', 1)) + + # Config. + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + # block_factor = G2.block_size/remote_hw.block_size + self.block_factor = int(os.getenv('PT_HPU_BLOCK_SIZE_FACTOR', '1')) + self.block_shape = None + self.is_hetero = os.getenv('PT_HPU_ENABLE_RESTORE_KV_LAYOUT', '0') == '1' + + if vllm_config.kv_transfer_config is None: + raise ValueError("kv_transfer_config must be set for NixlConnector") + + self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config( + "backends", ["UCX"] + ) + # TODO temporary, once nixl allows for telemetry flag in config + # (next release), we can remove this env var. + os.environ["NIXL_TELEMETRY_ENABLE"] = "1" + # Agent. + non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] + if nixl_agent_config is None: + config = None + else: + config = ( + nixl_agent_config(backends=self.nixl_backends) + if len(non_ucx_backends) > 0 + else nixl_agent_config(num_threads=8) + ) + + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + + # NIXL handshake port. + # NOTE(rob): Within a DP group, each DP rank gets its own + # base port (which is sent in the KVTransferParams). + # Each TP rank listens/queries on the base_port + tp_rank. + self.side_channel_port: int = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + + # Metadata. + self.engine_id: EngineId = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + self.num_blocks = 0 + self.enable_permute_local_kv = False + + # KV Caches and nixl tracking data. + self.device_type = current_platform.device_type + self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device + if self.device_type not in _NIXL_SUPPORTED_DEVICE: + raise RuntimeError(f"{self.device_type} is not supported.") + elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported." + ) + self.device_kv_caches: dict[str, torch.Tensor] = {} + + # cpu kv buffer for xfer + # used when device memory can not be registered under nixl + self.host_xfer_buffers: dict[str, torch.Tensor] = {} + self.use_host_buffer = self.kv_buffer_device == "cpu" + # support for oot platform which can't register nixl memory + # type based on kv_buffer_device + nixl_memory_type = current_platform.get_nixl_memory_type() + if nixl_memory_type is None: + if self.kv_buffer_device == "cuda" or self.kv_buffer_device == "hpu": + nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + nixl_memory_type = "DRAM" + if nixl_memory_type is None: + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported." 
+ ) + if self.kv_buffer_device == "cpu" and self.is_hetero: + self.remote_nixl_memory_type = "VRAM" + else: + self.nixl_memory_type = nixl_memory_type + + # Note: host xfer buffer ops when use_host_buffer is True + self.copy_blocks: CopyBlocksOp | None = None + + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + # nixl_prepped_dlist_handle. + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[EngineId, int] = {} + + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[EngineId, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. + # [req_id -> list[handle]] + self._recving_metadata: dict[ReqId, ReqMeta] = {} + self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) + # Track the expiration time of requests that are waiting to be sent. + self._reqs_to_send: dict[ReqId, float] = {} + # Set of requests that have been part of a batch, regardless of status. + self._reqs_to_process: set[ReqId] = set() + + # invalid blocks from failed NIXL operations + self._invalid_block_ids: set[int] = set() + # requests that skipped transfer (handshake or transfer failures) + self._failed_recv_reqs: set[ReqId] = set() + + # Background thread for handling new handshake requests. + self._nixl_handshake_listener_t: threading.Thread | None = None + # Background thread for initializing new NIXL handshakes. + self._handshake_initiation_executor = ThreadPoolExecutor( + # NIXL is not guaranteed to be thread-safe, limit 1 worker. + max_workers=1, + thread_name_prefix="vllm-nixl-handshake-initiator", + ) + self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() + self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + # Protects _handshake_futures and _remote_agents. + self._handshake_lock = threading.RLock() + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[int | None] = [] + self.use_mla = self.model_config.use_mla + + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) + self.backend_name = backend.get_name() + attn_backend = backend_name_to_enum(self.backend_name) + self._use_flashinfer = attn_backend == _Backend.FLASHINFER + self._use_pallas = attn_backend == _Backend.PALLAS + self.kv_cache_layout = get_kv_cache_layout() + logger.debug("Detected attention backend %s", self.backend_name) + logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + # With heterogeneous TP, P must wait for all assigned D TP workers to + # finish reading before safely freeing the blocks. 
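+    # Illustrative example (numbers assumed, not from this patch): with decode
+    # --tensor-parallel-size 4 and prefill --tensor-parallel-size 2,
+    # tp_ratio = 4 // 2 = 2, so each prefill worker expects two "done reading"
+    # notifications per request; the counter below tracks how many arrived.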
+ self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + self.xfer_stats = NixlKVConnectorStats() + +NixlConnectorWorker.__init__ = NixlConnectorWorker__init__ + +def add_remote_agent( + self, + nixl_agent_meta: NixlAgentMetadata, + remote_tp_rank: int = 0, + remote_tp_size: int = 1, +) -> str: + """ + Add the remote NIXL agent and prepare the descriptors for reading cache + blocks from remote. + + In particular, handle both homogeneous and heterogeneous TP. The former + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or + more local TP worker share the xfer from a single TP worker. + + Here's an example (non-MLA case): + + rank_offset p_remote_tp_rank + (kv split no) + -------------------------------- + 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] + / + 1 0 Worker1 ---- 2nd half of KV -----/ + + 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ] + / + 1 1 Worker3 ---- 2nd half of KV -----/ + + + Decoder TP workers Prefix TP workers + (world_size=4) (world_size=2) + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. + + Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 + so that the whole cache is shared by "tp_ratio" D TP workers. + """ # noqa: E501 + engine_id = nixl_agent_meta.engine_id + # TODO re-evaluate refreshing for scaling/recovery + if remote_tp_rank in self._remote_agents.get(engine_id, {}): + return self._remote_agents[engine_id][remote_tp_rank] + + if engine_id not in self._tp_size: + self._tp_size[engine_id] = remote_tp_size + else: + assert self._tp_size[engine_id] == remote_tp_size + # TODO We may eventually want to skip enforcing the same attn backend. + #assert nixl_agent_meta.attn_backend_name == self.backend_name + assert nixl_agent_meta.attn_backend_name == "FLASH_ATTN_VLLM_V1" or nixl_agent_meta.attn_backend_name == "HPU_ATTN_V1" + + remote_agent_name = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata + ) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + assert not self._use_pallas or tp_ratio == 1, ( + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + ) + + # Handle tp_size>num_kv_heads: replicate KV cache. 
+ total_num_kv_heads = self.model_config.get_total_num_kv_heads() + is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + + remote_block_len = nixl_agent_meta.block_lens[0] + if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout: + if ( + self.vllm_config.kv_transfer_config is not None + and self.vllm_config.kv_transfer_config.enable_permute_local_kv + and nixl_agent_meta.kv_cache_layout == "HND" + ): + logger.info( + "Remote is HND and local is NHD, enabled additional permute " + "on local device KV." + ) + self.enable_permute_local_kv = True + else: + raise RuntimeError( + "Heterogeneous TP expects same kv_cache_layout. " + "Or enable experimental feature to use HND to NHD support by " + "setting 'enable_permute_local_kv'=True in --kv-transfer-config." + ) + if self.use_mla or is_kv_replicated: + # With replicated KV cache, only the number of blocks can differ. + assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( + "KV cache sizes must match between P and D when replicated" + ) + remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) + else: + # When MLA is not used, this is a list of the same block length + for block_len in nixl_agent_meta.block_lens: + assert block_len == remote_block_len, ( + "All remote layers must have the same block size" + ) + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) + if self._use_flashinfer: + # With flashinfer, KV are sent in the same message. + remote_block_size //= 2 + if tp_ratio > 1: + # Heterogeneous TP expects same kv_cache_layout. + if nixl_agent_meta.kv_cache_layout == "NHD": + raise ValueError( + "Heterogeneous TP is not supported for remote with NHD." + ) + if self.device_type == "xpu": + raise ValueError("Heterogeneous TP is not supported on XPU") + + #assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + # "Remote P worker KV layer cache must be of shape [2, N, " + # "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + #) + + #assert self.block_size == remote_block_size, ( + # "Remote P worker with different page/block size is not supported " + # f"{self.block_size=}, {remote_block_size=}" + #) + + # Create dst descs and xfer side handles. TP workers have same #blocks. + if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + else: + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + + blocks_data = [] + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + # Register all remote blocks, but only the corresponding kv heads. + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + rank_offset = ( + self.tp_rank % tp_ratio * kv_block_len // tp_ratio + if not (self.use_mla or is_kv_replicated) + else 0 + ) + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_lens[i] + # For each block, grab the heads chunk belonging to rank_i + # of size remote_nheads // tp_ratio, which correspond to + # self.block_len == remote_block_len//tp_ratio bytes. 
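+            # Worked example (numbers assumed): with 8 remote KV heads split
+            # across tp_ratio = 2 decode workers, decode rank 0 registers the
+            # first half of every remote block (heads 0-3, rank_offset = 0)
+            # and decode rank 1 the second half (heads 4-7).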
+ addr = base_addr + block_offset + rank_offset + # (addr, len, device id) + #blocks_data.append((addr, kv_block_len // tp_ratio, remote_tp_rank)) + blocks_data.append((addr, nixl_agent_meta.block_lens[i]//tp_ratio, remote_tp_rank)) + + if self._use_flashinfer: + # With FlashInfer index V separately to allow head splitting. + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_lens[0] + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[0] // 2 + blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs + ) + + return remote_agent_name + +NixlConnectorWorker.add_remote_agent = add_remote_agent + +def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving on this specific worker. + The scheduler process (via the MultiprocExecutor) will use this output + to track which workers are done. + """ + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + + # add requests that skipped transfer to done_recving + done_recving.update(self._failed_recv_reqs) + self._failed_recv_reqs.clear() + + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", + self.tp_rank, + len(done_sending), + len(done_recving), + ) + + if self.is_hetero and self.kv_buffer_device == "hpu": + #import remote_pdb; remote_pdb.set_trace() + remote_block_size = self.block_size // self.block_factor + block_size, n_kv_heads, head_dim = self.block_shape + for req_id in done_recving: + #print(req_id, self._recving_metadata) + meta = self._recving_metadata.pop(req_id) + for k, v in self.device_kv_caches.values(): + local_block_ids = meta.local_block_ids + #print(f'buke {local_block_ids=}|{k.shape=}') + assert len(local_block_ids) == local_block_ids[-1]-local_block_ids[0] + 1 # simple check if the indices are contiguous + block_idx = local_block_ids[0] + num_blocks = len(local_block_ids) + k[block_idx*self.block_size: (num_blocks+block_idx)*self.block_size] = k[block_idx*self.block_size: (num_blocks+block_idx)*self.block_size].reshape(num_blocks*self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(num_blocks*self.block_size,n_kv_heads,head_dim) + v[block_idx*self.block_size: (num_blocks+block_idx)*self.block_size] = v[block_idx*self.block_size: (num_blocks+block_idx)*self.block_size].reshape(num_blocks*self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(num_blocks*self.block_size,n_kv_heads,head_dim) + #import remote_pdb; remote_pdb.set_trace() + + # clean up metadata for completed requests + if self.use_host_buffer: + for req_id in done_recving: + meta = self._recving_metadata.pop(req_id, None) + if self.use_host_buffer and meta: + self.sync_recved_kv_to_device(req_id, meta) + + # Handle timeout to avoid stranding blocks on remote. 
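+    # Illustrative note: a request handed off here whose decode worker never
+    # reads it stays in _reqs_to_send until VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+    # elapses, after which the loop below reports it as done sending so its
+    # KV blocks can be reused.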
+ now = time.perf_counter() + while self._reqs_to_send: + req_id, expires = next(iter(self._reqs_to_send.items())) + # Sorted dict, oldest requests are put first so we can exit early. + if now < expires: + break + count = self.consumer_notification_counts_by_req.pop(req_id, 0) + logger.warning( + "Releasing expired KV blocks for request %s which were " + "retrieved by %d decode worker(s) within %d seconds.", + req_id, + count, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + self._reqs_to_process.remove(req_id) + del self._reqs_to_send[req_id] + done_sending.add(req_id) + + if self.enable_permute_local_kv and len(done_recving) > 0: + block_ids = [] + for req_id in done_recving: + meta = self._recving_metadata.pop(req_id) + assert meta, f"{req_id} not found in recving_metadata list" + block_ids += meta.local_block_ids + + self.permute_device_kv(block_ids) + + return done_sending, done_recving + +NixlConnectorWorker.get_finished = get_finished + +def rewrite_kv_based_on_transfer_layout(self, metadata: NixlConnectorMetadata): + if self.decoder_tp_ratio == 1: + return + t = time.perf_counter() + for req_id, meta in metadata.reqs_to_save.items(): + block_ids = meta.local_block_ids + for k, v in self.device_kv_caches.items(): + gb, h, d = v[0].shape + indices = torch.tensor(block_ids, device=v[0].device) + gbhd = [int(gb/self.block_size), self.block_size, h, d] + for i in range(len(self.device_kv_caches[k])): + kv = v[i].reshape(gbhd) + kv_selected = torch.index_select(kv, 0, indices) + bc, bs, h, d = kv_selected.shape + shape = int(bs*h/self.decoder_tp_ratio*d) + blocks = torch.chunk(kv_selected, 2, dim=2) + vecs = [b.reshape([bc, shape]) for b in blocks] + kv_selected = torch.concat(vecs, dim=1).reshape(kv_selected.shape) + kv.index_copy_(dim=0, index=indices, source=kv_selected) + if len(metadata.reqs_to_save) > 0: + torch.hpu.synchronize() + +NixlConnectorWorker.rewrite_kv_based_on_transfer_layout = rewrite_kv_based_on_transfer_layout + +def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + logger.debug( + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, + req_id, + ) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + is_mem_hit=meta.is_mem_hit, + ) + +NixlConnectorWorker._read_blocks_for_req = _read_blocks_for_req + +def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + is_mem_hit: bool = False, +): + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Number of D TP workers that will read from dst P. Propagate tp_ratio + # on notification so that dst worker can wait before freeing blocks. + tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] + notif_id = f"{request_id}:{tp_ratio}".encode() + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. 
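+    # Example (illustrative): on a full prefix-cache hit the decode instance
+    # already holds every block locally, so local_block_ids is empty and the
+    # branch below only sends the notification that lets the prefill worker
+    # free its copy.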
+ num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + remote_rank = self.tp_rank // tp_ratio + agent_name = self._remote_agents[dst_engine_id][remote_rank] + try: + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + except Exception: + logger.exception( + "NIXL send_notif failed for request %s: " + "P worker blocks will be freed after timeout. " + "This may indicate network issues.", + request_id, + ) + self.xfer_stats.record_failed_notification() + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. + local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. + + # Get descs ids. + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray + if self.block_factor > 1: + local_sub_block_ids = [b for x in local_block_ids for b in range(x * self.block_factor, (x + 1) * self.block_factor)] + assert len(local_sub_block_ids) <= len(remote_block_ids) + valid_len = len(local_sub_block_ids) + logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{valid_len=} |{len(remote_block_ids)}') + if is_mem_hit: + remote_block_ids = remote_block_ids[-valid_len:] + else: + remote_block_ids = remote_block_ids[:valid_len] + local_block_ids = local_sub_block_ids[:valid_len] + logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=} | {is_mem_hit=}') + else: + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + if not self.block_window_per_layer: + # Default case: assume global attention + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids + ) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids + ) + else: + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + local_descs_list = [] + remote_descs_list = [] + for layer_idx, block_window in enumerate(self.block_window_per_layer): + # For each layer: + if block_window is None: + # If not chunked, we just use the + # full block lists (global attention) + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids + else: + # If chunked, get the last block_window blocks + layer_local_block_ids = local_block_ids[-block_window:] + layer_remote_block_ids = remote_block_ids[-block_window:] + + # Get descs ids for the layer. + layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx + ) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx + ) + + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) + + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate(remote_descs_list) + + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. 
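+    # Illustrative note: the prepared transfer pairs each local descriptor id
+    # with its remote counterpart; the READ is issued asynchronously and
+    # completion is later polled by _pop_done_transfers in get_finished.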
+ handle = None + try: + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=notif_id, + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + self._recving_transfers[request_id].append((handle, time.perf_counter())) + except Exception: + logger.exception( + "NIXL transfer setup/initiation failed for request %s. " + "Marking blocks as invalid.", + request_id, + ) + # mark all blocks for this request as invalid + if meta := self._recving_metadata.get(request_id): + self._invalid_block_ids.update(meta.local_block_ids) + self.xfer_stats.record_failed_transfer() + if handle is not None: + self.nixl_wrapper.release_xfer_handle(handle) + self._failed_recv_reqs.add(request_id) + +NixlConnectorWorker._read_blocks = _read_blocks + +def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + _, first_kv_cache = next(iter(kv_caches.items())) + if self.device_type == "hpu": + kv_elem_size = first_kv_cache[0][0].dtype.itemsize + else: + kv_elem_size = first_kv_cache.element_size() + + if self.use_host_buffer: + self.initialize_host_xfer_buffer(kv_caches=kv_caches) + assert len(self.host_xfer_buffers) == len(kv_caches), ( + f"host_buffer: {len(self.host_xfer_buffers)}, " + f"kv_caches: {len(kv_caches)}") + xfer_buffers = self.host_xfer_buffers + else: + xfer_buffers = kv_caches + assert not self.host_xfer_buffers, ( + "host_xfer_buffer should not be initialized when " + f"kv_buffer_device is {self.kv_buffer_device}") + + # TODO(tms): Find a more robust way to detect and handle MLA + # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected + # KV memory layout is HND, as opposed to the default NHD. Note that it + # will only affects the strides. For MLA instead, we make require no + # such thing and resort to the standard layout. + use_mla = len(first_kv_cache.shape) == 3 if self.device_type != "hpu" else False + if self.device_type == "hpu": + # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim] + #from remote_pdb import RemotePdb; RemotePdb('0.0.0.0', 4444).set_trace() + self.num_blocks = first_kv_cache[0].shape[0] // self.block_size + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache[0].shape[-block_rank:] + block_shape = list(block_shape) + block_shape[0] = block_shape[0] // self.num_blocks + block_shape = torch.Size(block_shape) + block_size, n_kv_heads, head_dim = block_shape[-3:] + self.block_shape = [block_size, n_kv_heads, head_dim] + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + else: + raise RuntimeError( + f"{self.device_type} ({self.backend_name}) is not supported.") + + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + # block size in bytes + self.block_len = kv_elem_size * math.prod(block_shape) + logger.info( + "Registering KV_Caches. 
use_mla: %s, kv_buffer_device: %s, " + "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " + "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, + self.use_host_buffer, self.num_blocks, block_shape, + first_kv_cache[0].shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks * self.block_factor + self.device_kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + seen_base_addresses = [] + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we cans + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + # Conversely for FlashInfer, K and V are transferred in the same tensor + # to better exploit the memory layout (ie num_blocks is the first dim). + tensor_size_bytes = None + self.block_len_per_layer = list[int]() + self.slot_size_per_layer = list[int]() # HD bytes in kv terms + for cache_or_caches in xfer_buffers.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla \ + else cache_or_caches + for cache in cache_list: + if self.device_type == "hpu" and not self.use_host_buffer and htutils is not None: + base_addr = htutils.experimental._data_ptr(cache) + logger.debug(f'buke register gaudi memory for gdr: {base_addr=}|{hex(base_addr)=}|{cache.data_ptr()=}') + else: + base_addr = cache.data_ptr() + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert cache.shape[0] == self.num_blocks, ( + "All kv cache tensors must have the same number of blocks" + ) + + self.block_len_per_layer.append( + curr_tensor_size_bytes // self.num_blocks + ) + self.slot_size_per_layer.append( + self.block_len_per_layer[-1] // self.block_size + ) + region_len = self.num_blocks * self.block_len + # NOTE: use tp_rank for device_id since multi-node TP + # is rarely used. 
+ caches_data.append((base_addr, region_len, self.tp_rank, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(xfer_buffers.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + + descs = self.nixl_wrapper.get_reg_descs(caches_data, + self.nixl_memory_type) + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) + + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks * self.block_factor): + block_offset = block_id * self.block_len // (self.block_factor) + addr = base_addr + block_offset + # (addr, len, device id) + # TODO: does device_id matter to DRAM? + blocks_data.append((addr, self.block_len//(self.block_factor), self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + #print(f'buke: {blocks_data[0:10]=}') + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # After KV Caches registered, listen for new connections. + metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + block_lens=self.block_len_per_layer, + attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout,) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. 
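A minimal, self-contained sketch (names and sizes assumed, not taken from the patch) of the sub-block bookkeeping that register_kv_caches sets up above when PT_HPU_BLOCK_SIZE_FACTOR > 1: each local HPU block is exposed to NIXL as block_factor smaller descriptors, so remote blocks of the smaller size can be read straight into it.

def build_local_descs(base_addr: int, num_blocks: int, block_len: int,
                      block_factor: int, device_id: int) -> list[tuple[int, int, int]]:
    # One (addr, length, device_id) descriptor per remote-sized sub-block.
    sub_len = block_len // block_factor
    return [(base_addr + sub_id * sub_len, sub_len, device_id)
            for sub_id in range(num_blocks * block_factor)]

def local_block_to_desc_ids(block_id: int, block_factor: int) -> list[int]:
    # A local block maps onto block_factor consecutive descriptor ids.
    return list(range(block_id * block_factor, (block_id + 1) * block_factor))

descs = build_local_descs(base_addr=0x1000, num_blocks=2, block_len=4096,
                          block_factor=4, device_id=0)
assert len(descs) == 8 and descs[1] == (0x1000 + 1024, 1024, 0)
assert local_block_to_desc_ids(1, block_factor=4) == [4, 5, 6, 7]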
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: @@ -38,13 +1050,13 @@ def _hpu_data_ptr(tensor_self): A temporary replacement for tensor.data_ptr(). Checks if the tensor is on an HPU device and if host buffers are not - in use, then calls the htexp._data_ptr utility. Otherwise, it falls + in use, then calls the htutils.experimental._data_ptr utility. Otherwise, it falls back to the original method. """ # The first `self` refers to the class instance (from the outer scope) # The `tensor_self` is the tensor instance on which .data_ptr() is called if tensor_self.device.type == 'hpu': - return htexp._data_ptr(tensor_self) + return htutils.experimental._data_ptr(tensor_self) # Fallback to the original implementation for CPU tensors or host buffers return original_data_ptr(tensor_self) @@ -53,3 +1065,4 @@ def _hpu_data_ptr(tensor_self): torch.Tensor.data_ptr = _hpu_data_ptr NixlConnectorWorker.initialize_host_xfer_buffer = initialize_host_xfer_buffer +NixlConnectorWorker.register_kv_caches = register_kv_caches diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ddb64fee9..a2220203c 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -99,6 +99,8 @@ hpu_buffer: list[list[torch.Tensor]] = [] +is_hetero = os.getenv('PT_HPU_ENABLE_RESTORE_KV_LAYOUT', '0') == '1' +block_factor = int(os.getenv('PT_HPU_BLOCK_SIZE_FACTOR', '1')) class BucketingFailedException(Exception): pass @@ -4474,6 +4476,7 @@ def copy_kv_blocks( direction: Literal["h2d", "d2h"], block_size: int = 128, ) -> None: + """Copy kv blocks between different buffers.""" if not src_kv_caches or not dst_kv_caches or \ not src_block_ids or not dst_block_ids or \ @@ -4483,47 +4486,72 @@ def copy_kv_blocks( src_device = next(iter(src_kv_caches.values()))[0].device dst_device = next(iter(dst_kv_caches.values()))[0].device - src_slot_mapping, dst_slot_mapping = _make_src_and_dst_indices(block_size=block_size, - src_block_ids=src_block_ids, - dst_block_ids=dst_block_ids, - src_device=src_device, - dst_device=dst_device) + src_slot_mapping, dst_slot_mapping = _make_src_and_dst_indices( + block_size=block_size, + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device) start = time.perf_counter() target_device = dst_device.type i = 0 global hpu_buffer - use_hpu_buffer = False + use_hpu_buffer = False # (len(src_slot_mapping) == hpu_buffer[0][0].size(0)) and (hpu_buffer is not None) for layer_name in src_kv_caches: key_cache = src_kv_caches[layer_name][0] value_cache = src_kv_caches[layer_name][1] - if direction == "d2h": - # NOTE(chendi): in order to keep host_buffer shape[0] same as tpu and gpu case - # so we need to flatten the dst_kv_caches - dst_kv_caches[layer_name] = dst_kv_caches[layer_name].flatten(1, 2) + if is_hetero: + assert direction == "h2d", "hetero only supports h2d for now" + n_kv_heads, head_dim = key_cache.shape[-2:] + remote_block_size = block_size//block_factor + #block_factor, n_kv_heads, remote_block_size, head_dim = 8, 8, 16, 128 + if len(src_block_ids) == src_block_ids[-1]-src_block_ids[0] + 1: # simple check if the indices are contiguous + block_idx = src_block_ids[0] + num_blocks = len(src_block_ids) + dst_kv_caches[layer_name][0][block_idx*block_size: (num_blocks+block_idx)*block_size] = key_cache[block_idx*block_size: (num_blocks+block_idx)*block_size].reshape(num_blocks*block_factor, n_kv_heads, remote_block_size, 
head_dim).permute(0,2,1,3).contiguous().reshape(num_blocks*block_size,n_kv_heads,head_dim) + dst_kv_caches[layer_name][1][block_idx*block_size: (num_blocks+block_idx)*block_size] = value_cache[block_idx*block_size: (num_blocks+block_idx)*block_size].reshape(num_blocks*block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(num_blocks*block_size,n_kv_heads,head_dim) + continue + for block_idx in src_block_ids: + #print('buke addr before:', dst_kv_caches[layer_name][0][block_idx*block_size: (1+block_idx)*block_size].data_ptr()) + dst_kv_caches[layer_name][0][block_idx*block_size: (1+block_idx)*block_size] = key_cache[block_idx*block_size: (1+block_idx)*block_size].reshape(block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(block_size,n_kv_heads,head_dim).to("hpu") + dst_kv_caches[layer_name][1][block_idx*block_size: (1+block_idx)*block_size] = value_cache[block_idx*block_size: (1+block_idx)*block_size].reshape(block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(block_size,n_kv_heads,head_dim).to("hpu") + #print('buke addr after:', dst_kv_caches[layer_name][0][block_idx*block_size: (1+block_idx)*block_size].data_ptr()) else: - key_cache = key_cache.flatten(0, 1) - if value_cache is not None: - value_cache = value_cache.flatten(0, 1) + ''' + if direction == "d2h": + # NOTE(chendi): in order to keep host_buffer shape[0] same as tpu and gpu case + # so we need to flatten the dst_kv_caches + dst_kv_caches[layer_name] = dst_kv_caches[layer_name].flatten(1, 2) + else: + key_cache = key_cache.flatten(0, 1) + if value_cache is not None: + value_cache = value_cache.flatten(0, 1) - if direction == "d2h" and use_hpu_buffer: - hpu_buffer[i][0] = key_cache.index_select(0, src_slot_mapping) - hpu_buffer[i][1] = value_cache.index_select(0, src_slot_mapping) - else: dst_kv_caches[layer_name][0].index_put_((dst_slot_mapping, ), key_cache.index_select(0, src_slot_mapping).to(target_device)) dst_kv_caches[layer_name][1].index_put_((dst_slot_mapping, ), value_cache.index_select(0, src_slot_mapping).to(target_device)) - if direction == "d2h": - dst_kv_caches[layer_name] = dst_kv_caches[layer_name].unflatten(1, (-1, block_size)) + if direction == "d2h": + dst_kv_caches[layer_name] = dst_kv_caches[layer_name].unflatten(1, (-1, block_size)) + ''' + if direction == "d2h" and use_hpu_buffer: + hpu_buffer[i][0]=key_cache.index_select_(0, src_slot_mapping) + hpu_buffer[i][1]=value_cache.index_select_(0, src_slot_mapping) + else: + #import remote_pdb;remote_pdb.set_trace() + dst_kv_caches[layer_name][0].index_put_((dst_slot_mapping,), key_cache.index_select(0, src_slot_mapping).to(target_device)) + dst_kv_caches[layer_name][1].index_put_((dst_slot_mapping,), value_cache.index_select(0, src_slot_mapping).to(target_device)) + i = i+1 + + #dst_kv_caches[layer_name][0][dst_slot_mapping] = key_cache[src_slot_mapping].to(target_device) + #dst_kv_caches[layer_name][1][dst_slot_mapping] = value_cache[src_slot_mapping].to(target_device) + #if use_hpu_buffer: + #tmp = hpu_buffer.to('cpu') + #dst_kv_caches = hpu_buffer.to('cpu') - i = i + 1 torch.hpu.synchronize() - logger.debug("copy_kv_blocks: copy takes %s" - "|direction=%s|pid=%s|block_size=%s" - "|src_blocks=%s|dst_blocks=%s", - time.perf_counter() - start, direction, os.getpid(), block_size, len(src_block_ids), - len(dst_block_ids)) + logger.info(f"copy_kv_blocks: copy takes {time.perf_counter() - 
start}|{direction=}|{os.getpid()=}|{block_size=}|{len(src_block_ids)=}|{len(dst_block_ids)=}| {len(src_kv_caches)=} | ") From 08832c00e2fdfe10a05671209841a411bccb12a1 Mon Sep 17 00:00:00 2001 From: Harish Subramony Date: Fri, 17 Oct 2025 21:12:50 +0000 Subject: [PATCH 2/2] working version of 1p1d , 1p2d (no register_kv_cache override) Signed-off-by: Harish Subramony --- .../kv_connector/v1/hpu_nixl_connector.py | 26 +++++++++---------- vllm_gaudi/v1/worker/hpu_model_runner.py | 3 ++- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py index aa8929ee1..601bcdddd 100644 --- a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py +++ b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py @@ -921,26 +921,26 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug(f'buke register gaudi memory for gdr: {base_addr=}|{hex(base_addr)=}|{cache.data_ptr()=}') else: base_addr = cache.data_ptr() - if base_addr in seen_base_addresses: - continue + if base_addr in seen_base_addresses: + continue - seen_base_addresses.append(base_addr) - curr_tensor_size_bytes = cache.numel() * cache.element_size() + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() - if tensor_size_bytes is None: - tensor_size_bytes = curr_tensor_size_bytes - self.num_blocks = cache.shape[0] + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] - assert cache.shape[0] == self.num_blocks, ( + assert cache.shape[0] == self.num_blocks, ( "All kv cache tensors must have the same number of blocks" ) - self.block_len_per_layer.append( + self.block_len_per_layer.append( curr_tensor_size_bytes // self.num_blocks - ) - self.slot_size_per_layer.append( + ) + self.slot_size_per_layer.append( self.block_len_per_layer[-1] // self.block_size - ) + ) region_len = self.num_blocks * self.block_len # NOTE: use tp_rank for device_id since multi-node TP # is rarely used. 
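For reference, a condensed sketch (illustrative tensors, not the patch's real caches) of the loop shape after the re-indentation above: every K/V tensor contributes its base address, per-layer block length and slot size exactly once.

import torch

def collect_regions(xfer_buffers: dict[str, list[torch.Tensor]],
                    num_blocks: int, block_size: int):
    seen, block_len_per_layer, slot_size_per_layer = [], [], []
    for caches in xfer_buffers.values():
        for cache in caches:  # K tensor, then V tensor
            base_addr = cache.data_ptr()
            if base_addr in seen:
                continue
            seen.append(base_addr)
            nbytes = cache.numel() * cache.element_size()
            block_len_per_layer.append(nbytes // num_blocks)
            slot_size_per_layer.append(block_len_per_layer[-1] // block_size)
    return seen, block_len_per_layer, slot_size_per_layer

layers = {"layer.0": [torch.zeros(8, 4, 2, 16), torch.zeros(8, 4, 2, 16)]}
addrs, lens, slots = collect_regions(layers, num_blocks=8, block_size=4)
assert len(addrs) == 2 and lens == [512, 512] and slots == [128, 128]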
@@ -1065,4 +1065,4 @@ def _hpu_data_ptr(tensor_self): torch.Tensor.data_ptr = _hpu_data_ptr NixlConnectorWorker.initialize_host_xfer_buffer = initialize_host_xfer_buffer -NixlConnectorWorker.register_kv_caches = register_kv_caches +#NixlConnectorWorker.register_kv_caches = register_kv_caches diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index a2220203c..4683adc0e 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4519,7 +4519,7 @@ def copy_kv_blocks( dst_kv_caches[layer_name][1][block_idx*block_size: (1+block_idx)*block_size] = value_cache[block_idx*block_size: (1+block_idx)*block_size].reshape(block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(block_size,n_kv_heads,head_dim).to("hpu") #print('buke addr after:', dst_kv_caches[layer_name][0][block_idx*block_size: (1+block_idx)*block_size].data_ptr()) else: - ''' + #''' if direction == "d2h": # NOTE(chendi): in order to keep host_buffer shape[0] same as tpu and gpu case # so we need to flatten the dst_kv_caches @@ -4543,6 +4543,7 @@ def copy_kv_blocks( #import remote_pdb;remote_pdb.set_trace() dst_kv_caches[layer_name][0].index_put_((dst_slot_mapping,), key_cache.index_select(0, src_slot_mapping).to(target_device)) dst_kv_caches[layer_name][1].index_put_((dst_slot_mapping,), value_cache.index_select(0, src_slot_mapping).to(target_device)) + ''' i = i+1 #dst_kv_caches[layer_name][0][dst_slot_mapping] = key_cache[src_slot_mapping].to(target_device)
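To close, a hedged end-to-end sketch of the heterogeneous h2d restore that copy_kv_blocks now performs: contiguous block-id ranges are reshaped in one shot, otherwise the restore falls back to a per-block copy. Shapes are illustrative; the real code operates on the HPU cache in place.

import torch

def restore_blocks(host_cache: torch.Tensor, block_ids: list[int],
                   block_size: int, block_factor: int) -> torch.Tensor:
    n_kv_heads, head_dim = host_cache.shape[-2:]
    remote_block_size = block_size // block_factor
    out = host_cache.clone()

    def restore_range(first: int, count: int) -> None:
        sl = slice(first * block_size, (first + count) * block_size)
        # Remote (HND) sub-blocks -> local per-token [heads, head_dim] layout.
        out[sl] = (host_cache[sl]
                   .reshape(count * block_factor, n_kv_heads, remote_block_size, head_dim)
                   .permute(0, 2, 1, 3)
                   .contiguous()
                   .reshape(count * block_size, n_kv_heads, head_dim))

    if len(block_ids) == block_ids[-1] - block_ids[0] + 1:  # contiguous ids
        restore_range(block_ids[0], len(block_ids))
    else:                                                   # per-block fallback
        for b in block_ids:
            restore_range(b, 1)
    return out

cache = torch.randn(4 * 128, 8, 64)  # 4 blocks of 128 slots, 8 heads, dim 64
restored = restore_blocks(cache, [1, 2], block_size=128, block_factor=8)
assert restored.shape == cache.shape

Keeping the contiguous fast path avoids issuing one reshape/permute per block when a request lands in consecutive local blocks, which is the common case in these tests.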