MLA #70

Draft: wants to merge 8 commits into base: deepseek00

Changes from all commits
435 changes: 435 additions & 0 deletions vllm/attention/backends/rocm_aiter_mla.py

Large diffs are not rendered by default.

100 changes: 100 additions & 0 deletions vllm/attention/ops/rocm_aiter_mla.py
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
max_block_per_batch: int,
device: torch.device) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
dtype=torch.int32,
device=device)
paged_kv_indptr = torch.zeros(max_batch_size + 1,
dtype=torch.int32,
device=device)
paged_kv_last_page_lens = torch.full((max_batch_size, ),
block_size,
dtype=torch.int32)
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
logit_cap: float = 0.0,
):

torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
kv_buffer.view(
-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)


def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd

mla_decode_fwd(q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap)


def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass


if current_platform.is_rocm():
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=mla_decode_fwd_fake,
tags=[torch.Tag.needs_fixed_stride_order])
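For orientation, a minimal sketch of how these helpers might be driven for a single decode step. The shapes, dtype, and scale are illustrative assumptions, the metadata tensors would normally be populated by the backend's metadata builder, and this only runs on a ROCm build of vLLM with the aiter package installed.

import torch

from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
                                               get_aiter_mla_metadata)

device = torch.device("cuda")
max_batch_size, block_size, max_blocks_per_batch = 4, 1, 256
head_dim = 576  # e.g. kv_lora_rank + qk_rope_head_dim for a DeepSeek-style MLA model

# Pre-allocated paged-KV metadata; in vLLM these are filled in per batch.
kv_indices, kv_indptr, kv_last_page_lens, qo_indptr = get_aiter_mla_metadata(
    max_batch_size, block_size, max_blocks_per_batch, device)

num_tokens, num_heads = max_batch_size, 16  # one decode token per request (assumed)
q = torch.randn(num_tokens, num_heads, head_dim,
                dtype=torch.bfloat16, device=device)
kv_buffer = torch.randn(1024, block_size, 1, head_dim,
                        dtype=torch.bfloat16, device=device)  # paged KV cache
o = torch.empty_like(q)

# The wrapper views kv_buffer as (-1, 1, 1, head_dim) and dispatches to the
# registered rocm_aiter_mla_decode_fwd custom op, which writes into `o` in place.
aiter_mla_decode_fwd(q, kv_buffer, o, sm_scale=head_dim ** -0.5,
                     qo_indptr=qo_indptr, max_seqlen_qo=1,
                     kv_indptr=kv_indptr, kv_indices=kv_indices,
                     kv_last_page_lens=kv_last_page_lens)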
55 changes: 55 additions & 0 deletions vllm/attention/utils/fa_utils.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

from vllm import envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
device_capability = current_platform.get_device_capability()

assert device_capability is not None

# 1. default version depending on platform
fa_version = 3 if (device_capability.major == 9
and is_fa_version_supported(3)) else 2

# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION

# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform "
"defaulting to FA version 2.")
fa_version = 2

if requires_alibi and fa_version == 3:
logger.warning_once("Cannot use FA version 3 with ALiBi, "
"defaulting to FA version 2.")
fa_version = 2

if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version))

assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None


def flash_attn_supports_fp8() -> bool:
from vllm.platforms import current_platform
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
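A short, hedged usage sketch of the helpers above (assumes a CUDA build; where vllm_flash_attn cannot be imported, get_flash_attn_version simply returns None):

from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

fa_version = get_flash_attn_version()      # 3 on Hopper (SM 9.x) when FA3 is built,
                                           # otherwise 2, or None if FA is unavailable
fa_alibi = get_flash_attn_version(requires_alibi=True)  # never 3 when ALiBi is needed
print(fa_version, fa_alibi, flash_attn_supports_fp8())  # fp8 needs FA3 on SM 9.x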
12 changes: 9 additions & 3 deletions vllm/engine/arg_utils.py
@@ -548,7 +548,7 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 64, 128],
choices=[1, 8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to ``--max-model-len``. On CUDA devices, '
@@ -1522,8 +1522,14 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:

# No FlashInfer or XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"ROCM_AITER_MLA",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
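For reference, a hedged end-to-end example of exercising the new choices from the offline API. The model name and the pairing of block size 1 with the ROCM_AITER_MLA backend (see the vllm/platforms/rocm.py hunk below) are illustrative assumptions, and this requires a ROCm build of vLLM with aiter installed.

import os

# Select the newly listed V1 backend and opt into the AITER MLA path.
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"
os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

from vllm import LLM

# block_size=1 is now an accepted value (mirrors --block-size 1 on the CLI).
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", block_size=1)
print(llm.generate("Hello")[0].outputs[0].text)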
9 changes: 7 additions & 2 deletions vllm/envs.py
@@ -79,6 +79,7 @@
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_ASMMOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
@@ -546,19 +547,23 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),


# Whether to use aiter asm moe ops.
    # Enabled by default.
"VLLM_ROCM_USE_AITER_ASMMOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_ASMMOE", "False").lower() in
("true", "1")),


# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),

    # Whether to use aiter MLA ops.
    # Enabled by default.
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),

# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
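A quick, standalone illustration of the parsing convention shared by these entries (not vLLM code): values are compared case-insensitively, and anything other than "true" or "1" disables the flag.

import os

os.environ["VLLM_ROCM_USE_AITER_MLA"] = "TRUE"
enabled = os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1")
print(enabled)  # True; "0", "false", "off", etc. would all evaluate to False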
6 changes: 1 addition & 5 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -576,14 +576,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_use_asm = (self.rocm_aiter_moe_enabled
and envs.VLLM_ROCM_USE_AITER_ASMMOE)

print(f"rocm_aiter_moe_enabled: {self.rocm_aiter_moe_enabled}")
print(f"rocm_aiter_use_asm: {self.rocm_aiter_use_asm}")

# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
@@ -780,7 +777,6 @@ def apply(
e_score_correction_bias=e_score_correction_bias,
)


if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -837,4 +837,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
return None
2 changes: 2 additions & 0 deletions vllm/platforms/interface.py
@@ -39,6 +39,8 @@ class _Backend(enum.Enum):
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
36 changes: 34 additions & 2 deletions vllm/platforms/rocm.py
@@ -140,8 +140,40 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_mla:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
from vllm.attention.backends.rocm_aiter_mla import (
is_aiter_mla_enabled)

if selected_backend is None:
selected_backend = (_Backend.ROCM_AITER_MLA if
is_aiter_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA)

if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA \
or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
if block_size == 1:
if use_v1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
"AITER MLA backend is not ported on V0 engine.")
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
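Restated as a standalone sketch (the helper below is illustrative, not part of the PR): with no backend explicitly selected, AITER MLA is preferred whenever it is enabled or the block size is 1; it then requires block size 1 and the V1 engine, while Triton MLA keeps handling every other block size.

def pick_rocm_mla_backend(aiter_mla_enabled: bool, block_size: int,
                          use_v1: bool) -> str:
    # Illustrative mirror of the use_mla branch above when selected_backend is None.
    if aiter_mla_enabled or block_size == 1:
        if block_size != 1:
            raise ValueError("ROCM_AITER_MLA currently only supports block size 1")
        if not use_v1:
            raise ValueError("AITER MLA backend is not ported to the V0 engine")
        return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
    # Triton MLA handles the remaining (non-1) block sizes.
    return "vllm.attention.backends.triton_mla.TritonMLABackend"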
12 changes: 6 additions & 6 deletions vllm/v1/attention/backends/mla/common.py
@@ -500,12 +500,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
# longer context lengths
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)

# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
if self.aot_schedule:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)

assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
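A tiny worked example of the page-size alignment that is now applied only under aot_schedule (round_down is re-declared locally to match the vllm.utils semantics; the numbers are made up):

def round_down(x: int, align: int) -> int:
    # Largest multiple of `align` that does not exceed `x`.
    return (x // align) * align

max_context_chunk, page_size = 1000, 64
# 1000 -> 960, so every context_chunk_start stays aligned to a page boundary,
# which the gather_cache kernel requires.
print(round_down(max_context_chunk, page_size))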