From a38b978b85e4a07a7fd296b6ce9ffd164c8420d9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 7 Jul 2025 08:14:07 +0000 Subject: [PATCH 1/4] port cuda graph Signed-off-by: vllmellm --- csrc/attention/merge_attn_states.cu | 8 + vllm/_custom_ops.py | 17 +- vllm/attention/layer.py | 17 +- vllm/compilation/backends.py | 206 +----------- vllm/compilation/base_piecewise_backend.py | 72 ++++ vllm/compilation/counter.py | 2 +- vllm/compilation/cuda_piecewise_backend.py | 218 +++++++++++++ vllm/config.py | 19 +- vllm/entrypoints/llm.py | 3 +- vllm/forward_context.py | 28 +- .../model_executor/layers/rotary_embedding.py | 5 +- vllm/model_executor/models/deepseek_v2.py | 31 +- vllm/platforms/interface.py | 7 + vllm/platforms/rocm.py | 4 + vllm/v1/attention/backends/mla/common.py | 182 +++++------ vllm/v1/attention/backends/mla/flashmla.py | 2 +- .../attention/backends/mla/rocm_aiter_mla.py | 17 +- vllm/v1/attention/backends/mla/triton_mla.py | 2 +- vllm/v1/attention/backends/utils.py | 308 ++++++++++++++++++ vllm/v1/worker/block_table.py | 11 + vllm/v1/worker/gpu_input_batch.py | 2 + vllm/v1/worker/gpu_model_runner.py | 149 +++++++-- 22 files changed, 942 insertions(+), 368 deletions(-) create mode 100644 vllm/compilation/base_piecewise_backend.py create mode 100644 vllm/compilation/cuda_piecewise_backend.py create mode 100644 vllm/v1/attention/backends/utils.py diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 14e5edd7e283..6bee9e4ce116 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); + TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1, + "output heads must be contiguous in memory"); + TORCH_CHECK( + prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1, + "prefix_output heads must be contiguous in memory"); + TORCH_CHECK( + suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1, + "suffix_output heads must be contiguous in memory"); float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bd930bb90653..59f460e86912 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -158,8 +158,13 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - torch.ops._C.rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox) + # TODO: Remove this contiguous call when the kernel is updated to support tensor slices + query_contiguous = query.contiguous() + key_contiguous = key.contiguous() + torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous, + head_size, cos_sin_cache, is_neox) + query.copy_(query_contiguous) + key.copy_(key_contiguous) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + # TODO: Remove this contiguous call when the kernel is updated to support tensor slices + query_contiguous = query.contiguous() + key_contiguous = key.contiguous() + torch.ops._C.batched_rotary_embedding(positions, query_contiguous, + key_contiguous, head_size, cos_sin_cache, is_neox, rot_dim, cos_sin_cache_offsets) + query.copy_(query_contiguous) + key.copy_(key_contiguous) # layer norm ops diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 68452f4c03b0..eefcd99e428c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -183,7 +183,9 @@ def forward( `vllm.forward_context.get_forward_context().attn_metadata`. """ if self.calculate_kv_scales: - attn_metadata = get_forward_context().attn_metadata + attn_metadata: ForwardContext = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] if attn_metadata.enable_kv_scales_calculation: self.calc_kv_scales(query, key, value) if self.use_output: @@ -209,6 +211,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -225,6 +229,8 @@ def forward( if self.use_direct_call: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -343,6 +349,7 @@ def wait_for_kv_layer_from_connector(layer_name: str): if attn_metadata is None: return + assert isinstance(attn_metadata, dict) connector.wait_for_layer_load(layer_name) @@ -360,6 +367,7 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return + assert isinstance(attn_metadata, dict) connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) @@ -372,6 +380,10 @@ def unified_attention( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -410,6 +422,9 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45988c2e9b0d..9c62fd014c18 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -5,9 +5,7 @@ import os import pprint import time -from contextlib import ExitStack -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple -from unittest.mock import patch +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch import torch.fx as fx @@ -15,12 +13,12 @@ import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname from .compiler_interface import EagerAdaptor, InductorAdaptor from .counter import compilation_counter from .inductor_pass import InductorPass -from .monitor import end_monitoring_torch_compile from .pass_manager import PostGradPassManager logger = init_logger(__name__) @@ -267,7 +265,9 @@ def call_module(self, target: torch.fx.node.Target, num_graphs=len(self.compile_submod_names), runtime_shape=None) - self.module.__dict__[target] = PiecewiseBackend( + piecewise_backend = resolve_obj_by_qualname( + current_platform.get_piecewise_backend_cls()) + self.module.__dict__[target] = piecewise_backend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape, self.vllm_backend) @@ -515,197 +515,3 @@ def copy_and_call(*args): return self.split_gm(*list_args) return copy_and_call - - -@dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - - compiled: bool = False - runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[List[int]] = None - - -class PiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: List[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): - """ - The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. - """ - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) - - self.compile_sizes: Set[int] = set( - self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: Set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() - - self.first_run_finished = False - - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - - self.sym_shape_indices = sym_shape_indices - - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - # the entries for different shapes that we need to either - # compile or capture cudagraph - self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, - ) - - def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: - # no specific sizes to compile - # save the hash of the inductor graph for the next run - self.vllm_backend.compiler_manager.save_to_file() - end_monitoring_torch_compile(self.vllm_config) - - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) - # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) - - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() - - if not entry.use_cudagraph: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_caputured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py new file mode 100644 index 000000000000..4d7aeeb4d03e --- /dev/null +++ b/vllm/compilation/base_piecewise_backend.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +import torch.fx as fx + +from vllm.compilation.backends import VllmBackend +from vllm.config import VllmConfig + + +class AbstractPiecewiseBackend(Protocol): + """ + PiecewiseBackend interface that allows platforms to extend + piecewise static graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, **kwargs): + """ + Initializes the PiecewiseBackend class with compilation and + execution-related configurations. + + This class handles piecewise compilation, graph capturing, + and dispatching for specific input shapes. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + piecewise_compile_index (int): + Index of the current piecewise subgraph. + total_piecewise_compiles (int): + Total number of piecewise-compiled graphs. + sym_shape_indices (list[int]): + Indices of symbolic shape. + compiled_graph_for_general_shape (Callable): + Callable that executes the graph compiled for general shapes. + vllm_backend (VllmBackend): + Backend compiler that manages compilation and graph runtime + for vLLM. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. + """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """Executes the compiled graph for given input args. + + If this is the first invocation, executes the general compiled graph + and initiates the compilation process tracking. For subsequent calls, + dynamically dispatches execution to either a compiled graph or a static + graph based on the input shape. + + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. + + Returns: + Any: Output of the executed graph. This can be from the general + compiled graph, a specialized compiled version for the given shape, + or a replayed static graph. + """ + raise NotImplementedError diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 5be452593c62..d2c847e28ce2 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -14,7 +14,7 @@ class CompilationCounter: # not including the splitting ops num_piecewise_capturable_graphs_seen: int = 0 num_backend_compilations: int = 0 - num_cudagraph_caputured: int = 0 + num_cudagraph_captured: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self) diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py new file mode 100644 index 000000000000..8c49ea6cc107 --- /dev/null +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. + """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + # Skip CUDA graphs if this entry doesn't use them OR + # if we're supposed to skip them globally + skip_cuda_graphs = get_forward_context().skip_cuda_graphs + if not entry.use_cudagraph or skip_cuda_graphs: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/config.py b/vllm/config.py index 5b5ac40f6aa2..97b8518fa9c2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3344,6 +3344,10 @@ class CompilationConfig(BaseModel): are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an internally managed buffer. Default is False. + - full_cuda_graph: whether to use a full cuda graph for the entire forward + pass rather than splitting certain operations such as attention into subgraphs. + Thus this flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models - Inductor compilation: - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. @@ -3388,6 +3392,7 @@ class CompilationConfig(BaseModel): cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[list[int]] = None cudagraph_copy_inputs: bool = False + full_cuda_graph: bool = False class PassConfig(BaseModel): """ @@ -3606,10 +3611,13 @@ def init_with_cudagraph_sizes(self, self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): - # If default, override splitting ops for piecewise cudagraph on V1. # NOTE: this function needs to be called + if self.splitting_ops and self.full_cuda_graph: + raise ValueError("full_cuda_graph cannot be used together with " + "splitting_ops, as Full CUDA graph will override " + f"the splitting_ops: {self.splitting_ops}") if not self.splitting_ops: - self.splitting_ops = [ + self.splitting_ops = [] if self.full_cuda_graph else [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] @@ -3862,6 +3870,13 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.compilation_config.full_cuda_graph and \ + not self.model_config.disable_cascade_attn: + logger.warning_once( + "full_cuda_graph is not supported with " + "cascade attention. Disabling cascade attention.") + self.model_config.disable_cascade_attn = True + if self.model_config and self.model_config.use_mla and \ not (current_platform.is_cuda() or current_platform.is_rocm()): logger.info( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57c7ab73de37..95f7f19d5f9c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -186,7 +186,8 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, dict[str, Any]]] = None, + compilation_config: Optional[Union[int, dict[str, Any], + CompilationConfig]] = None, **kwargs, ) -> None: ''' diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 06790d8ee2f8..f6431600a831 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed as dist @@ -39,11 +39,19 @@ class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass - # TODO: remove after making all virtual_engines share the same kv cache + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, map from layer_name of each + attention layer to its attention metadata + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[ + str, + "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None + skip_cuda_graphs: bool = False _forward_context: Optional[ForwardContext] = None @@ -58,10 +66,13 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context(attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: int = 0): +def set_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + skip_cuda_graphs: bool = False, +): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -101,7 +112,8 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, - dp_metadata=dp_metadata) + dp_metadata=dp_metadata, + skip_cuda_graphs=skip_cuda_graphs) # KVConnector: trigger (possibly async) load before forward. # Each attn layer will block until the reading is complete. diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 624ed63ab8b4..978952040807 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -778,8 +778,9 @@ def forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) + if self.cos_sin_cache.device != positions.device: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 79934cafb5a8..4e02caee144a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -454,10 +454,7 @@ def __init__( qk_rope_head_dim=self.qk_rope_head_dim, qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, ) self.prefix = prefix @@ -469,17 +466,29 @@ def forward( hidden_states: torch.Tensor, ) -> torch.Tensor: if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) + q_c = self.q_a_proj(hidden_states)[0] + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] else: - hidden_states_or_q_c = hidden_states + q = self.q_proj(hidden_states)[0] kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) + + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_local_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] class DeepseekV2DecoderLayer(nn.Module): @@ -837,4 +846,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): return layer_idx + i - return None \ No newline at end of file + return None diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5931a620dba7..22184d0c0e40 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -406,6 +406,13 @@ def validate_request( ) -> None: """Raises if this request is unsupported on this platform""" + @classmethod + def get_piecewise_backend_cls(cls) -> str: + """ + Get piecewise backend class for piecewise graph. + """ + return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 0e94acd60b80..bea89803ce3b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -344,3 +344,7 @@ def use_custom_allreduce(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName supported_archs = ['gfx94'] return any(gfx in gcn_arch for gfx in supported_archs) + + @classmethod + def get_piecewise_backend_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f28dd02de547..520de5eee22b 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -199,11 +199,14 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, RowParallelLinear, + LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: @@ -255,6 +258,10 @@ def get_supported_head_sizes() -> list[int]: def use_cascade_attention(*args, **kwargs) -> bool: return False + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + return common_attn_metadata.max_query_len == 1 + @dataclass class MLACommonPrefillMetadata: @@ -270,9 +277,6 @@ class ChunkedContextMetadata: max_seq_lens: list[int] workspace: torch.Tensor - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -281,9 +285,6 @@ class ChunkedContextMetadata: @dataclass class MLACommonDecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor @@ -316,9 +317,6 @@ class MLACommonMetadata(Generic[D]): num_decode_tokens: int num_prefills: int - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - # The dimension of the attention heads head_dim: Optional[int] = None @@ -337,7 +335,7 @@ def __post_init__(self): M = TypeVar("M", bound=MLACommonMetadata) -class MLACommonMetadataBuilder(Generic[M]): +class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -345,6 +343,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -357,10 +357,11 @@ def __init__(self, runner.parallel_config) self.mla_dims = get_mla_dims(model_config) self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: - self.page_size = self.runner.block_size + self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -387,6 +388,8 @@ def __init__(self, device=runner.device, ) + self.block_table = block_table + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are and @@ -447,38 +450,58 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, seq_lens: torch.Tensor): + def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor): return MLACommonDecodeMetadata( - input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, ) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> M: + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + assert m.num_reqs == m.num_actual_tokens, \ + "MLA only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + # Update state usually set in reorder_batch. + self._num_decodes = m.num_reqs + self._num_decode_tokens = m.num_actual_tokens + self._num_prefills = 0 + self._num_prefill_tokens = 0 + return self.build(0, m) + + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() - - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(device, non_blocking=True) + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start - tokens_start = self._num_decode_tokens context_lens_cpu = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[reqs_start:num_reqs] @@ -500,6 +523,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # longer context lengths max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) + if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_cache` kernel cannot handle @@ -546,8 +570,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - input_positions=input_positions[tokens_start:], - block_table=block_table[reqs_start:, ...], + block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -556,8 +579,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, decode_metadata = None if self._num_decodes > 0: decode_metadata = self._build_decode( - input_positions=input_positions[:self._num_decode_tokens], - block_table=block_table[:self._num_decodes, ...], + block_table_tensor=block_table_tensor[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) @@ -600,13 +622,7 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - rotary_emb: RotaryEmbedding, - # q_proj should be q_b_proj if q_lora_rank is not None, but from an - # attention backend perspective we rely on the layer to pass in the - # correct matrix - q_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear, - o_proj: RowParallelLinear, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -621,17 +637,7 @@ def __init__( self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - # Hack for V1 for now to avoid torch library overhead (since we are - # already inside an attention custom op), pull out the forward - # method from the rotary embedding and call it directly - # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb.forward_native - if current_platform.is_cuda(): - self.rotary_emb = rotary_emb.forward_cuda - - self.q_proj = q_proj self.kv_b_proj = kv_b_proj - self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn @@ -688,27 +694,13 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x)[0] - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.num_heads, self.qk_head_dim)\ - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe + return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -878,7 +870,11 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - return self.o_proj(output.flatten(start_dim=-2))[0] + # unpad if necessary + if self._pad_v: + output = output[..., :v.shape[-1]] + + return output.flatten(start_dim=-2) @abstractmethod def _forward_decode( @@ -893,32 +889,37 @@ def _forward_decode( def forward( self, layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn + q: torch.Tensor, k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: M, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + if attn_metadata is None: - # Profiling run. - return output + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_toks, ...] - hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None @@ -927,31 +928,12 @@ def forward( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - decode_k_pe = k_pe[:num_decode_tokens] + decode_q = q[:num_decode_tokens] - prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + prefill_q = q[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] - if has_decode: - assert attn_metadata.decode is not None - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe.contiguous(), - decode_k_pe) - - if has_prefill: - assert attn_metadata.prefill is not None - prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), prefill_k_pe) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -969,6 +951,16 @@ def forward( attn_metadata) if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 143bfe35bb5e..f18c9c8b6462 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -146,4 +146,4 @@ def _forward_decode( causal=True, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 68245913ee15..6c2694b2ab94 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -14,6 +14,8 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable # yapf: enable @@ -61,8 +63,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) assert self.runner.block_size == 1, "AITER MLA" \ "only supports block size 1." @@ -100,8 +103,7 @@ def _get_paged_kv_tensors( qo_indptr, ) - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: ( @@ -109,11 +111,10 @@ def _build_decode(self, input_positions: torch.Tensor, paged_kv_indptr, paged_last_page_len, qo_indptr, - ) = self._get_paged_kv_tensors(block_table, seq_lens) + ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) attn_metadata = AiterMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, @@ -205,4 +206,4 @@ def _forward_decode( attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 8e7e4f10b81b..2e6b619db628 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -115,4 +115,4 @@ def _forward_decode( attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py new file mode 100644 index 000000000000..73d17b1c7d4b --- /dev/null +++ b/vllm/v1/attention/backends/utils.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import abc +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar + +import numpy as np +import torch + +from vllm.utils import cdiv + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +# from vllm.distributed.kv_transfer.kv_connector.utils import ( +# get_kv_connector_cache_layout) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class CommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + """ + + query_start_loc: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: torch.Tensor + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + max_query_len: int + """Longest query in batch""" + + +M = TypeVar("M") + + +class AttentionMetadataBuilder(abc.ABC, Generic[M]): + # Does this backend/builder support CUDA Graphs for attention. + full_cudagraph_supported: ClassVar[bool] = False + + @abstractmethod + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: + """ + Central method that builds attention metadata. + Some builders (MLA) require reorder_batch to be called prior to build. + """ + raise NotImplementedError + + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + """ + Can this batch (with given metadata) use CUDA Graphs for attention. + """ + return False + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + Build attention metadata for CUDA graph capture. Uses build by default. + Subclasses that override this method should call self.build or + super().build_for_cudagraph_capture. + """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + def use_cascade_attention( + self, + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + num_sms: int, + ) -> bool: + return False + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + """ + This method can reorder the batch if desired by the backend. + :return: Has the batch been reordered (default False). + """ + return False + + +def validate_kv_sharing_target(current_layer_name, target_layer_name, + static_forward_context): + error_msg = (f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} ") + + if current_layer_name == target_layer_name: + raise ValueError(error_msg + + "cannot be the same as the current layer.") + + if target_layer_name not in static_forward_context: + from vllm.model_executor.models.utils import extract_layer_index + + # If target layer name is not in the static fwd context, it means either + # a) the target layer does not come BEFORE the current layer, or + # b) the target layer is not an Attention layer that exists in the model + current_layer_idx = extract_layer_index(current_layer_name) + target_layer_idx = extract_layer_index(target_layer_name) + if current_layer_idx <= target_layer_idx: + raise ValueError(error_msg + "must come before the current layer.") + else: + raise ValueError(error_msg + + "is not a valid Attention layer in the model.") + + # Currently KV sharing is only supported between layers of the same type + target_layer_attn_type = static_forward_context[ + target_layer_name].attn_type + expected = static_forward_context[current_layer_name].attn_type + if target_layer_attn_type != expected: + raise ValueError( + error_msg + + f"must be the same type as the current layer ({expected}).") + + +# @functools.lru_cache +# def get_kv_cache_layout(): +# # Override with format specified by the user. +# cache_layout = envs.VLLM_KV_CACHE_LAYOUT +# if cache_layout is None: +# cache_layout = get_kv_connector_cache_layout() +# else: +# logger.info_once("`FLASHINFER_KV_CACHE_LAYOUT` environment variable " \ +# "detected. Setting KV cache layout to %s.", cache_layout) + +# return cache_layout + + +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. +# +# For example, if are performing a chunked prefill a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. if we have: +# attn_chunk_size = 4 +# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) +# Then this function would return: +# __b0__ ______b1______ __b2__ < orig batch indices +# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] +# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] +# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] +# block_table_local : shape[local_virtual_batches, pages_per_local_batch] +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + block_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. + # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # new_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), + q_seqlens).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, + attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: max a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = \ + np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), + attn_chunk_size)[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ + .astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], + attn_chunk_size, + dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ + (rarange * attn_chunk_size + \ + np.repeat(tokens_in_last_block, local_blocks)) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // block_size + assert attn_chunk_size % block_size == 0, \ + f"attn_chunk_size {attn_chunk_size} is not " \ + f"divisible by block_size {block_size}" + pages_per_local_batch = attn_chunk_size // block_size + + # Create a block_table for the local attention blocks + # For out example if we have a block-table like (assuming block_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices= np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch)) \ + + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch) + block_table_local = block_table[batch_indices, block_indices]\ + .view(virtual_batches, -1) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ + block_table_local diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992..581d3d9bd11b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,11 +14,13 @@ def __init__( self, max_num_reqs: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -36,6 +38,15 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + def append_row( self, block_ids: list[int], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a64cb97e0123..4c691a065bc6 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -60,6 +60,7 @@ def __init__( max_num_reqs: int, max_model_len: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, @@ -101,6 +102,7 @@ def __init__( self.block_table = BlockTable( max_num_reqs=max_num_reqs, max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ac0701c45986..6ed97b760135 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,12 +3,13 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch import torch.distributed import torch.nn as nn +from tqdm import tqdm from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention @@ -29,7 +30,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LayerBlockType, LazyLoader, cdiv, check_use_alibi, is_pin_memory_available) -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, @@ -69,6 +70,7 @@ def __init__( self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config self.lora_config = vllm_config.lora_config self.load_config = vllm_config.load_config self.parallel_config = vllm_config.parallel_config @@ -137,8 +139,6 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -181,12 +181,13 @@ def __init__( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=model_config.get_vocab_size(), ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level + self.use_cuda_graph = (self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -194,9 +195,9 @@ def __init__( # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes)) + self.full_cuda_graph = self.compilation_config.full_cuda_graph # Cache the device properties. self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -208,6 +209,16 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.query_start_loc = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + self.seq_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -258,11 +269,6 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -472,8 +478,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor, + ) -> tuple[dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata]]: + """ + :return: tuple[ + attn_metadata: layer-to-attention_metadata mapping, + attention_cuda_graphs: whether attention can run in cudagraph + logits_indices, spec_decode_metadata + ] + """ + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -545,7 +559,8 @@ def _prepare_inputs( block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + out=self.input_batch.block_table. + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -569,7 +584,37 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if enabled & beneficial. + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + self.seq_lens[num_reqs:].fill_(0) + self.query_start_loc[num_reqs + 1:].fill_(-1) + + query_start_loc = self.query_start_loc[:num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens) + + attn_metadata: dict[str, Any] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + common_prefix_len = 0 if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( @@ -578,12 +623,13 @@ def _prepare_inputs( ) attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, ) + attention_cuda_graphs = self.attn_metadata_builder.can_run_in_cudagraph( + common_attn_metadata) + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -592,7 +638,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = attn_metadata.query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -612,7 +658,7 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices, spec_decode_metadata + return attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata def _compute_cascade_attn_prefix_len( self, @@ -1000,7 +1046,7 @@ def execute_model( return EMPTY_MODEL_RUNNER_OUTPUT # Prepare the decoder inputs. - attn_metadata, logits_indices, spec_decode_metadata = ( + attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph @@ -1012,7 +1058,6 @@ def execute_model( else: # Eager mode. num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1062,9 +1107,19 @@ def execute_model( for k, v in self.intermediate_tensors.items() }) + # Some attention backends only support CUDA Graphs in pure decode. + # If attention doesn't support CUDA Graphs for this batch, but we + # compiled with full CUDA graphs, we have to skip them entirely. + skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + skip_cuda_graphs=skip_cuda_graphs, + ): hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -1399,8 +1454,8 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, + capture_attn_cudagraph: bool = False, ) -> torch.Tensor: - # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. @@ -1415,6 +1470,25 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + if capture_attn_cudagraph: + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1] + # Make sure max_model_len is used at the graph capture time. + self.seq_lens_np[:num_reqs] = self.max_model_len + self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + seq_lens = self.seq_lens[:num_reqs] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + ) + self.attn_metadata_builder.build_for_cuda_graph_capture( + common_attn_metadata, ) + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1620,11 +1694,14 @@ def capture_model(self) -> None: # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): - for num_tokens in reversed(self.cudagraph_batch_sizes): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + full_cg = self.full_cuda_graph + for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), + desc="Capturing CUDA graphs", + total=len(self.cudagraph_batch_sizes)): + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -1675,10 +1752,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # KV cache specs. raise ValueError("Unknown KV cache spec type.") - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self), + kv_cache_config.kv_cache_groups[0].kv_cache_spec, + self.input_batch.block_table) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -1689,7 +1770,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.vllm_config.compilation_config.static_forward_context + forward_ctx = self.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} From 04f4131a3f9f934be81371b961df8f01cbe3cebe Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 7 Jul 2025 09:05:04 +0000 Subject: [PATCH 2/4] add full graph support in aiter mla Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 112 +++++++++++------- 1 file changed, 72 insertions(+), 40 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 6c2694b2ab94..d5f9dfaea065 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, ClassVar, Optional import torch @@ -62,63 +63,91 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + full_cudagraph_supported: ClassVar[bool] = True # decode only def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table) - assert self.runner.block_size == 1, "AITER MLA" \ + super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata) + assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." - def _get_paged_kv_tensors( - self, block_table: torch.Tensor, - seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - page_size = self.runner.block_size + # Preparing persistent buffers + if self.runner.full_cuda_graph: + device = self.runner.device + max_num_reqs = self.runner.max_num_reqs + self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, + dtype=torch.int32, + device=device) + self.paged_kv_indices = torch.zeros( + block_table.get_device_tensor().numel( + ), # max num pages possible + dtype=torch.int32, + device=device) + self.paged_kv_last_page_len = torch.zeros(max_num_reqs, + dtype=torch.int32, + device=device) + + self.qo_indptr = torch.arange(0, + max_num_reqs + 1, + dtype=torch.int32, + device=device) + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size + device = self.runner.device - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) paged_kv_indptr = torch.cat([ - torch.zeros(1, - dtype=block_table_bounds.dtype, - device=block_table_bounds.device), + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) - qo_indptr = torch.arange(0, - self._num_decodes + 1, - step=1, - dtype=torch.int32, - device=block_table_bounds.device) - return ( - paged_kv_indices, - paged_kv_indptr, - paged_kv_last_page_len, - qo_indptr, - ) + if self.runner.full_cuda_graph: + num_reqs = self._num_decodes - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, + non_blocking=True) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + + self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, + non_blocking=True) + self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs] - ( - paged_kv_indices, - paged_kv_indptr, - paged_last_page_len, - qo_indptr, - ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) + self.paged_kv_last_page_len[:num_reqs].copy_( + paged_kv_last_page_len, non_blocking=True) + self.paged_kv_last_page_len[num_reqs:].fill_(1) + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + + qo_indptr = self.qo_indptr[:1 + num_reqs] + + else: + qo_indptr = torch.arange(0, + self._num_decodes + 1, + step=1, + dtype=torch.int32, + device=device) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_last_page_len, + paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr) return attn_metadata @@ -138,13 +167,17 @@ def __init__( blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) - + kv_sharing_target_layer_name, **mla_args) + assert (num_heads == 16 or num_heads == 128), ( + f"Aiter MLA only supports 16 or 128 number of heads.\n" + f"Provided {num_heads} number of heads.\n" + "Try adjusting tensor_parallel_size value.") unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap ] @@ -199,7 +232,6 @@ def _forward_decode( # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, attn_metadata.decode.qo_indptr, max_seqlen_qo, attn_metadata.decode.paged_kv_indptr, From 45fd5d642e2834f1ab0ca8c0f54f86d2d8865358 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 7 Jul 2025 15:37:49 +0000 Subject: [PATCH 3/4] bugfixes Signed-off-by: vllmellm --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d5f9dfaea065..8d1a0791a06b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -167,13 +167,12 @@ def __init__( blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, - kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + **mla_args) assert (num_heads == 16 or num_heads == 128), ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6ed97b760135..cfd1d7717620 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1486,7 +1486,7 @@ def _dummy_run( num_actual_tokens=num_tokens, max_query_len=num_tokens, ) - self.attn_metadata_builder.build_for_cuda_graph_capture( + self.attn_metadata_builder.build_for_cudagraph_capture( common_attn_metadata, ) with self.maybe_dummy_run_with_lora(self.lora_config, From d1bf7f0ba51ca648a9c4f5cab446738b8b167895 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 8 Jul 2025 05:46:05 +0000 Subject: [PATCH 4/4] add example for enabilng cuda graph Signed-off-by: vllmellm --- .../basic/generate_with_full_graph.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 examples/offline_inference/basic/generate_with_full_graph.py diff --git a/examples/offline_inference/basic/generate_with_full_graph.py b/examples/offline_inference/basic/generate_with_full_graph.py new file mode 100644 index 000000000000..bdc3901e8c67 --- /dev/null +++ b/examples/offline_inference/basic/generate_with_full_graph.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + + return parser + + +def main(args: dict): + # Pop arguments not used by LLM + max_tokens = args.pop("max_tokens") + temperature = args.pop("temperature") + top_p = args.pop("top_p") + top_k = args.pop("top_k") + + # Create an LLM + args.pop("compilation_config", + None) # Remove compilation_config if it exists + args.pop("max_num_seqs", None) # Remove max_num_seqs if it exists + llm = LLM(**args, + max_num_seqs=256, + compilation_config={ + "full_cuda_graph": True, + "cudagraph_capture_sizes": [64, 256] + }) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if temperature is not None: + sampling_params.temperature = temperature + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + print("-" * 50) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + +if __name__ == "__main__": + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args)