Commit 7e56265

code review changes
Signed-off-by: Harish Subramony <[email protected]>
1 parent 141da43 commit 7e56265

7 files changed: 14 additions & 217 deletions

.github/workflows/pre-merge.yaml

Lines changed: 5 additions & 3 deletions
@@ -47,9 +47,8 @@ jobs:
 RUN git checkout main

 # Pinning versions in requirements might be good practice for CI consistency
-RUN pip install pytest pytest_asyncio nixl==0.4.1
+RUN pip install pytest pytest_asyncio
 RUN pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git
-RUN pip install lm-eval[api]

 ENV no_proxy=localhost,127.0.0.1
 ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
@@ -114,7 +113,10 @@ jobs:
 -e HF_HOME=/workspace/hf_cache \
 -v /mnt/hf_cache:/workspace/hf_cache \
 hpu-plugin-v1-test-env-pre-merge-${{ github.event.pull_request.head.sha }} \
-/bin/bash "/workspace/vllm-gaudi/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh"
+/bin/bash -c "
+  pip install nixl==0.4.1 lm-eval[api] &&
+  /workspace/vllm-gaudi/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+"

 EXITCODE=$?
 echo "Test script exited with code: $EXITCODE"

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -6,4 +6,5 @@ tabulate
 setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 numba
-transformers>=4.1,<4.56.0
+transformers>=4.1,<4.56.0
+nixl==0.4.1

tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh

Lines changed: 0 additions & 159 deletions
This file was deleted.

vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_base.py

Lines changed: 1 addition & 10 deletions
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
 import torch
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, CopyBlocksOp)
+    KVConnectorBase_V1)
 from vllm_gaudi.extension.logger import logger as init_logger

 logger = init_logger()
@@ -24,14 +24,6 @@ def from_raw_dict(
     return None


-def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
-    """
-    Set the xPU-specific ops for copying KV between host and device.
-    Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
-    """
-    return
-
-
 # ==============================
 # Scheduler-side methods
 # ==============================
@@ -44,6 +36,5 @@ def set_kv_transfer_params(self, request: "Request"):
         request.raw_kv_transfer_params)
     request.kv_transfer_params = kv_transfer_params

-KVConnectorBase_V1.set_host_xfer_buffer_ops = set_host_xfer_buffer_ops
 KVConnectorBase_V1.set_kv_transfer_params = set_kv_transfer_params
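The pattern in this file is monkey-patching: module-level functions are assigned onto the imported KVConnectorBase_V1 class so every connector subclass inherits them; this commit drops the set_host_xfer_buffer_ops patch and keeps only set_kv_transfer_params. A minimal, self-contained sketch of the pattern (Base and extra_method are stand-in names, not from the plugin):

    class Base:
        pass

    def extra_method(self, value):
        # behaves exactly as if it had been defined on Base
        self.value = value
        return self.value

    # attach the function to the class; existing and future subclasses pick it up
    Base.extra_method = extra_method

    class Child(Base):
        pass

    assert Child().extra_method(42) == 42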

vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py

Lines changed: 1 addition & 37 deletions
@@ -73,43 +73,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # will only affects the strides. For MLA instead, we make require no
         # such thing and resort to the standard layout.
         use_mla = len(first_kv_cache.shape) == 3 if self.device_type != "hpu" else False
-        if self.device_type == "tpu":
-            assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
-            assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
-            # tpu (v1) kv shape per layer:
-            # (num_blocks, block_size, num_kv_heads * 2, head_size)
-            self.num_blocks = first_kv_cache.shape[0]
-            block_rank = 3  # [block_size, kv_heads, head_dim]
-            block_shape = first_kv_cache.shape[-block_rank:]
-            block_size, n_kv_heads_x_2, head_dim = block_shape
-            self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
-        elif self.device_type == "cuda":
-            assert use_mla == self.use_mla
-            # TODO (NickLucche) not compatible with hybrid allocator.
-            # Enforce check once it goes live, as a single kv layout
-            # is expected for xfers.
-            if use_mla:
-                # MLA case.
-                self.num_blocks = first_kv_cache.shape[0]
-                block_rank = 2  # [block_size, latent_dim]
-                block_shape = first_kv_cache.shape[-block_rank:]
-                block_size, kv_latent_dim = block_shape
-                self.slot_size_bytes = kv_elem_size * kv_latent_dim
-            else:
-                # [2 (k and v), num_blocks, ...]
-                if self._use_flashinfer:
-                    # FlashInfer swaps 2<->num_blocks dimensions.
-                    self.num_blocks = first_kv_cache.shape[0]
-                    block_rank = 4  # [2, block_size, kv_heads, head_dim]
-                else:
-                    self.num_blocks = first_kv_cache.shape[1]
-                    block_rank = 3  # [block_size, kv_heads, head_dim]
-                block_shape = first_kv_cache.shape[-block_rank:]
-                block_size, n_kv_heads, head_dim = block_shape[-3:]
-                # head size in bytes.
-                self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
-            assert block_size == self.block_size
-        elif self.device_type == "hpu":
+        if self.device_type == "hpu":
             # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim]
             #from remote_pdb import RemotePdb; RemotePdb('0.0.0.0', 4444).set_trace()
             self.num_blocks = first_kv_cache[0].shape[0] // self.block_size
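With the TPU and CUDA branches removed, the connector only has to reason about the Gaudi layout, where each per-layer cache is a K/V pair of tensors shaped [num_blocks*block_size, kv_heads, head_dim], so the block count is recovered by integer division of the flattened first dimension. A toy illustration of that arithmetic (the dimensions and the slot-size naming are made up for the example, mirroring but not taken from the connector):

    import torch

    num_blocks, block_size, kv_heads, head_dim = 16, 128, 8, 64
    # habana-style layout: [2 (K and V), num_blocks * block_size, kv_heads, head_dim]
    kv_cache = torch.zeros(2, num_blocks * block_size, kv_heads, head_dim,
                           dtype=torch.bfloat16)

    kv_elem_size = kv_cache.element_size()                  # 2 bytes for bfloat16
    derived_num_blocks = kv_cache[0].shape[0] // block_size
    slot_size_bytes = kv_elem_size * kv_heads * head_dim    # bytes per token slot, per K or V

    assert derived_num_blocks == num_blocks
    print(derived_num_blocks, slot_size_bytes)              # 16 1024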

vllm_gaudi/platform.py

Lines changed: 2 additions & 3 deletions
@@ -78,7 +78,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 128
-        #vllm_config.kv_transfer_config.kv_buffer_device = 'hpu'
         if (parallel_config.distributed_executor_backend in ['mp', 'uni']
                 and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
             if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
@@ -121,8 +120,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

     @classmethod
     def is_pin_memory_available(cls):
-        logger.warning("Pin memory is supported on HPU.")
-        return True
+        logger.warning("Pin memory is not supported on HPU.")
+        return False

     @classmethod
     def get_punica_wrapper(cls) -> str:
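is_pin_memory_available now reports False (and the warning text is corrected to match), so callers fall back to ordinary pageable host memory rather than page-locked buffers. A generic sketch of how such a flag is typically consumed when allocating a host staging buffer (this is not the plugin's actual call site, just an illustration):

    import torch

    def alloc_host_buffer(numel: int, dtype: torch.dtype, pin: bool) -> torch.Tensor:
        # pin_memory=True requests page-locked host memory (needs accelerator support);
        # with pin=False this is a plain pageable CPU tensor, which is the HPU path here.
        return torch.empty(numel, dtype=dtype, pin_memory=pin)

    staging = alloc_host_buffer(1024, torch.float32, pin=False)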

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 3 additions & 4 deletions
@@ -622,7 +622,7 @@ def __init__(
             self.parallel_config)
         self.head_size = self.model_config.get_head_size()
         self.hidden_size = self.model_config.get_hidden_size()
-        logger.debug(f'buke model config: {self.model_config=}')
+        logger.debug(f'model config: {self.model_config=}')
         self.attn_backend = get_attn_backend(
             self.head_size,
             self.dtype,
@@ -2302,7 +2302,7 @@ def execute_model(
         if not has_kv_transfer_group():
             # Return empty ModelRunnerOuptut if there's no work to do.
             return EMPTY_MODEL_RUNNER_OUTPUT
-        #logger.info(f'buke before kv_connector_no_forward |{os.getpid()=}|{scheduler_output.total_num_scheduled_tokens=}|{scheduler_output=}')
+        #logger.debug(f'before kv_connector_no_forward |{os.getpid()=}|{scheduler_output.total_num_scheduled_tokens=}|{scheduler_output=}')
         # For D case, wait until kv finish load here
         return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
     # If necessary, swap decodes/prompts to have all decodes on the start
@@ -2658,7 +2658,6 @@ def execute_model(
             finished_recving=finished_recving,
         )
     )
-    #logger.debug(f"buke hpu_model_runner.py: {model_runner_output=}")
     if has_kv_transfer_group():
         get_kv_transfer_group().clear_connector_metadata()
     return model_runner_output
@@ -3181,7 +3180,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         v_cache_shape = None if self.model_config.use_mla \
             else kv_cache_shape
         dtype = kv_cache_spec.dtype
-        #logger.debug(f'buke: |{os.getpid()=}|{kv_cache_shape=}')
+        logger.debug(f'|{os.getpid()=}|{kv_cache_shape=}')
         key_cache = torch.zeros(kv_cache_shape,
                                 dtype=dtype,
                                 device=self.device)
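The cleaned-up debug lines drop the stray "buke" author tag and rely on the f-string `=` specifier, which renders both the expression and its value, so the log line documents itself. For reference, a small example of that specifier:

    import os

    pid = os.getpid()
    print(f"|{pid=}|")               # e.g. |pid=12345|
    kv_cache_shape = (2, 2048, 8, 64)
    print(f"|{kv_cache_shape=}|")    # |kv_cache_shape=(2, 2048, 8, 64)|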
