188 changes: 95 additions & 93 deletions tests/distributed/test_context_parallel.py
@@ -16,16 +16,35 @@
import pytest
import torch

from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
from vllm.config.model import RunnerOption
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

logger = init_logger("test_context_parallel")

VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
"Qwen/Qwen2.5-1.5B-Instruct",
]

# GSM8K eval configuration
NUM_QUESTIONS = 256 # Fast eval for CI
NUM_SHOTS = 5 # Few-shot examples
# tp accuracy with 2% buffer
MIN_ACCURACY = {
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
"Qwen/Qwen2.5-1.5B-Instruct": 0.52,
}


class ParallelSetup(NamedTuple):
tp_size: int
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):

class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str | None = None


@@ -54,17 +72,20 @@ def detailed(
*,
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_multipliers: list[float] | None = None,
cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str | None = None,
):
parallel_setups = []
if dcp_multipliers is None:
dcp_multipliers = [
0.5,
]
for eager_mode_val in [False]:
for pp_multiplier in [1]:
for dcp_multiplier in [0.5, 1]:
for dcp_multiplier in dcp_multipliers:
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(
@@ -82,7 +103,6 @@ def detailed(
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)
@@ -101,7 +121,24 @@ def iter_params(self, model_id: str):
)


def _compare_cp_with_tp(
CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(
dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
),
],
"Qwen/Qwen2.5-1.5B-Instruct": [
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
),
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
),
],
}


def _test_cp_gsm8k(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
chunked_prefill,
) = parallel_setup

multi_node_only, load_format, attn_backend = test_options
multi_node_only, attn_backend = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides

if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}

if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
model_info.check_available_online(on_fail="skip")

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -157,90 +179,70 @@
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")

common_args = [
server_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"4096",
"--max-num-seqs",
"8",
"64",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
server_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
server_args.append("--enforce-eager")
if runner != "auto":
common_args.extend(["--runner", runner])
server_args.extend(["--runner", runner])
if trust_remote_code:
common_args.append("--trust-remote-code")
server_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
server_args.extend(["--tokenizer-mode", tokenizer_mode])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

if not attn_backend:
cp_env = tp_env = {}
else:
cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}

cp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

server_args.extend(
[
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
)

tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--distributed-executor-backend",
distributed_backend,
]
server_env = {}
if attn_backend:
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend

compare_two_settings(
with RemoteOpenAIServer(
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
server_args,
env_dict=server_env,
max_wait_seconds=720,
)


CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
}
) as remote_server:
host = f"http://{remote_server.host}"
port = remote_server.port

# Run GSM8K evaluation
results = evaluate_gsm8k(
num_questions=NUM_QUESTIONS,
num_shots=NUM_SHOTS,
host=host,
port=port,
)

CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
"bigcode/gpt_bigcode-santacoder",
]
# Validate accuracy is reasonable
accuracy = results["accuracy"]
min_accuracy = MIN_ACCURACY[model_id]
assert accuracy >= min_accuracy, (
f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
)


@pytest.mark.parametrize(
@@ -274,12 +276,12 @@ def test_cp_generation(
):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if (
model_id == "bigcode/gpt_bigcode-santacoder"
model_id == "Qwen/Qwen2.5-1.5B-Instruct"
and torch.cuda.get_device_capability() != (9, 0)
):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")

_compare_cp_with_tp(
_test_cp_gsm8k(
model_id,
parallel_setup,
distributed_backend,
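For reviewers skimming the diff, the sketch below condenses the new test flow: launch a vLLM OpenAI-compatible server with the decode-context-parallel flags, run the GSM8K eval against it, and assert the accuracy floor from `MIN_ACCURACY`. It is an illustration only; the concrete flag values (TP/DCP sizes, interleave size, attention backend, executor backend) are assumptions standing in for the values that `CPTestSettings.detailed()` generates.

```python
# Illustrative sketch of the new test flow (not part of the change itself).
# Flag values below are example assumptions; the real test derives them from
# ParallelSetup / CPTestOptions as shown in the diff.
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer

MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
server_args = [
    "--dtype", "bfloat16",
    "--max-model-len", "4096",
    "--max-num-seqs", "64",
    "--enable-chunked-prefill",
    "--tensor-parallel-size", "4",
    "--decode-context-parallel-size", "2",
    "--dcp-kv-cache-interleave-size", "16",
    "--distributed-executor-backend", "mp",
]
server_env = {"VLLM_ATTENTION_BACKEND": "FLASH_ATTN"}

with RemoteOpenAIServer(
    MODEL_ID, server_args, env_dict=server_env, max_wait_seconds=720
) as remote_server:
    # Score the served model on GSM8K and check it stays within the
    # 2% buffer of the TP baseline recorded in MIN_ACCURACY.
    results = evaluate_gsm8k(
        num_questions=256,
        num_shots=5,
        host=f"http://{remote_server.host}",
        port=remote_server.port,
    )
    assert results["accuracy"] >= 0.52
```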
6 changes: 5 additions & 1 deletion tests/models/registry.py
@@ -413,7 +413,11 @@ def check_available_online(
trust_remote_code=True,
),
"Qwen2ForCausalLM": _HfExamplesInfo(
"Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
"Qwen/Qwen2-0.5B-Instruct",
extras={
"2.5": "Qwen/Qwen2.5-0.5B-Instruct",
"2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
},
),
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
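Aside on the registry change above: the new "2.5-1.5B" extra is presumably what lets the context-parallel test resolve `Qwen/Qwen2.5-1.5B-Instruct` via `HF_EXAMPLE_MODELS.find_hf_info()` before launching the server. A minimal sketch, assuming `find_hf_info` matches ids listed under `extras` as well as the primary model id:

```python
# Minimal sketch (assumption: find_hf_info also matches "extras" ids,
# which is why the Qwen2.5-1.5B entry is added above).
from tests.models.registry import HF_EXAMPLE_MODELS

info = HF_EXAMPLE_MODELS.find_hf_info("Qwen/Qwen2.5-1.5B-Instruct")
# The CP test reads fields like these to build the server command line.
print(info.tokenizer_mode, info.hf_overrides)
```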