Merged
Changes from all commits (20 commits)
aae62b1
WIP enabling llama4 models
afierka-intel Sep 3, 2025
3da91bf
Fix importing and typing
afierka-intel Sep 3, 2025
bc4c16b
Merge branch 'vllm-project:main' into dev/afierka/llama4-enabling
afierka-intel Sep 3, 2025
46d9625
Remove inheritance from HPUPagedAttentionMetadataBuilder - llama4 works
afierka-intel Sep 3, 2025
cbcff01
Enable Llama4 multimodal functionality
afierka-intel Sep 8, 2025
b18c8fb
Formatting
afierka-intel Sep 9, 2025
41cf4ab
Merge branch 'vllm-project:main' into dev/afierka/llama4-enabling
afierka-intel Sep 9, 2025
f639be9
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 9, 2025
ece6955
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 10, 2025
40a13a5
Update hpu_attn.py
afierka-intel Sep 10, 2025
22ae094
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 10, 2025
53c03b5
Fix pre-commit issues
afierka-intel Sep 10, 2025
c942e37
Handle exception when is_embed is None
afierka-intel Sep 10, 2025
93ca40b
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 11, 2025
7676336
Merge branch 'main' into dev/afierka/llama4-enabling
kzawora-intel Sep 11, 2025
ee670e8
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 11, 2025
3c4a903
Merge branch 'main' into dev/afierka/llama4-enabling
afierka-intel Sep 12, 2025
c7f9c49
Return forward_native in HPULlama4VisionRotaryEmbedding
afierka-intel Sep 12, 2025
b4fc5c1
Typo: forward_native -> forward_oot
afierka-intel Sep 12, 2025
49f800c
Merge branch 'main' into dev/afierka/llama4-enabling
kzawora-intel Sep 12, 2025
26 changes: 12 additions & 14 deletions .jenkins/test_config.yaml
@@ -81,13 +81,12 @@ stages:
cd .jenkins/lm-eval-harness &&
PT_HPU_LAZY_MODE=1
bash run-tests.sh -c configs/models-fp8.txt -t 2
- # Chendi: llama4 upstream modeling changed, need to fix
- # - name: gsm8k_fp8_llama4_scout_g3_tp2_compressed_tensor
- #   flavor: g3.s
- #   command: >-
- #     cd .jenkins/lm-eval-harness &&
- #     VLLM_CONTIGUOUS_PA=False PT_HPU_LAZY_MODE=1
- #     bash run-tests.sh -c configs/models-fp8-compressedtensor.txt -t 2
+ - name: gsm8k_fp8_llama4_scout_g3_tp2_compressed_tensor
+   flavor: g3.s
+   command: >-
+     cd .jenkins/lm-eval-harness &&
+     VLLM_WEIGHT_LOAD_FORCE_SYNC=1 VLLM_CONTIGUOUS_PA=False PT_HPU_LAZY_MODE=1
+     bash run-tests.sh -c configs/models-fp8-compressedtensor.txt -t 2
# Chendi: crash on model weight loading, need to fix
# - name: gsm8k_fp8_qwen3_30B_g3_tp1_block_scale_dynamic
# flavor: g3
@@ -102,10 +101,9 @@
# cd .jenkins/lm-eval-harness &&
# VLLM_CONTIGUOUS_PA=False PT_HPU_LAZY_MODE=1 VLLM_HPU_FORCE_CHANNEL_FP8=0
# bash run-tests.sh -c configs/models-fp8-blockfp8.txt -t 1
- # Chendi: comment multimodal test since it is not enabled in V1 yet.
- # - name: multimodal_llama4_scout_g3_tp2_ep
- #   flavor: g3.s
- #   command: >-
- #     cd .jenkins/vision &&
- #     PT_HPU_LAZY_MODE=1 VLLM_WEIGHT_LOAD_FORCE_SYNC=1
- #     bash run-tests.sh -c configs/models-llama4-scout.txt -t 2
+ - name: multimodal_llama4_scout_g3_tp2_ep
+   flavor: g3.s
+   command: >-
+     cd .jenkins/vision &&
+     PT_HPU_LAZY_MODE=1 VLLM_WEIGHT_LOAD_FORCE_SYNC=1
+     bash run-tests.sh -c configs/models-llama4-scout.txt -t 2
7 changes: 6 additions & 1 deletion vllm_gaudi/attention/backends/hpu_attn.py
@@ -19,7 +19,8 @@
AttentionType)
from vllm.attention.backends.mla.common import MLACommonImpl
from vllm.attention.backends.utils import CommonAttentionState
- from vllm_gaudi.attention.ops.hpu_paged_attn import (HPUPagedAttention, HPUPagedAttentionMetadata)
+ from vllm_gaudi.attention.ops.hpu_paged_attn import (HPUPagedAttention, HPUPagedAttentionMetadata,
+                                                       HPUPagedAttentionMetadataBuilder)

from vllm_gaudi.extension.logger import logger as init_logger
from vllm_gaudi.extension.unified import (unified_attn, HPUUnifiedAttentionMetadata)
@@ -45,6 +46,10 @@ def get_metadata_cls() -> type["AttentionMetadata"]:
def get_state_cls() -> type["CommonAttentionState"]:
raise NotImplementedError()

+     @staticmethod
+     def get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]:
+         return HPUPagedAttentionMetadataBuilder
+
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
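A minimal sketch of how the new hook might be consumed by a caller. The enclosing backend class name is not visible in this excerpt and is assumed here to be HPUAttentionBackend, and my_input_builder is a caller-supplied placeholder; only get_builder_cls() and HPUPagedAttentionMetadataBuilder come from this diff.

    from vllm_gaudi.attention.backends.hpu_attn import HPUAttentionBackend  # class name assumed

    builder_cls = HPUAttentionBackend.get_builder_cls()
    builder = builder_cls(input_builder=my_input_builder)  # my_input_builder is hypothetical
    builder.prepare()                                       # per-batch reset (currently a no-op)
    metadata_cls = builder.build(seq_lens=[128], query_lens=[1],
                                 cuda_graph_pad_size=0, batch_size=1)
    # build() currently returns the HPUPagedAttentionMetadata class itself, not an instance.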
22 changes: 22 additions & 0 deletions vllm_gaudi/attention/ops/hpu_paged_attn.py
@@ -24,6 +24,28 @@ class HPUPagedAttentionMetadata:
alibi_blocks: Optional[torch.Tensor]


+ @dataclass
+ class HPUPagedAttentionMetadataBuilder:
+
+     def __init__(self, input_builder: "HPUPageAttentionInputBuilderBase") -> None:
+         """Create the builder, remember some configuration and parameters."""
+         self.input_builder = input_builder
+
+     def prepare(self) -> None:
+         """Prepare for one batch."""
+         pass
+
+     def build(self, seq_lens: list[int], query_lens: list[int], cuda_graph_pad_size: int,
+               batch_size: int) -> type[HPUPagedAttentionMetadata]:
+         """Build attention metadata with on-device tensors."""
+         return HPUPagedAttentionMetadata
+
+
+ @dataclass
+ class HPUPageAttentionInputBuilderBase:
+     pass
+
+
class HPUPagedAttention:

    @staticmethod
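A short usage sketch of the stub builder added above, grounded only in the code in this hunk; the argument values are illustrative.

    from vllm_gaudi.attention.ops.hpu_paged_attn import (HPUPagedAttentionMetadata,
                                                          HPUPagedAttentionMetadataBuilder,
                                                          HPUPageAttentionInputBuilderBase)

    # __init__ only stores the input builder, so the placeholder base class suffices here.
    builder = HPUPagedAttentionMetadataBuilder(input_builder=HPUPageAttentionInputBuilderBase())
    builder.prepare()  # currently a no-op
    meta_cls = builder.build(seq_lens=[16, 16], query_lens=[1, 1],
                             cuda_graph_pad_size=0, batch_size=2)
    assert meta_cls is HPUPagedAttentionMetadata  # build() returns the class, not an instance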
3 changes: 2 additions & 1 deletion vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1155,7 +1155,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput", req_ids: list

self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
- is_embed=pos_info.is_embed,
+ is_embed=pos_info.is_embed.to(
+     device=output.device) if pos_info.is_embed is not None else pos_info.is_embed,
)

# modified from: vllm/v1/worker/gpu_model_runner.py
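Read in isolation, the guard added above follows this pattern, where output stands in for the encoder output tensor and is_embed for pos_info.is_embed; the tensor values are illustrative only.

    import torch

    output = torch.randn(4, 8)                           # stands in for the encoder output
    is_embed = torch.tensor([True, False, True, True])   # may also be None for some requests

    # Move the mask onto the output's device only when it exists; pass None through unchanged.
    is_embed = is_embed.to(device=output.device) if is_embed is not None else is_embed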