6 changes: 1 addition & 5 deletions lmdeploy/api.py
@@ -68,11 +68,7 @@ def pipeline(model_path: str,
         if backend_config is not None else None
     model_path = get_model(model_path, download_dir, revision)
 
-    task, pipeline_class = get_task(model_path)
-    if task == 'vlm':
-        if backend_config and backend_config.enable_prefix_caching:
-            backend_config.enable_prefix_caching = False
-            logger.warning('VLM does not support prefix caching.')
+    _, pipeline_class = get_task(model_path)
 
     if type(backend_config) is not PytorchEngineConfig:
         # set auto backend mode
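With the guard above removed, enable_prefix_caching stays in effect for vision-language models. A minimal usage sketch (the model path is only an example; any VLM supported by the PyTorch engine should work):

    from lmdeploy import PytorchEngineConfig, pipeline

    # Prefix caching is no longer force-disabled for VLMs.
    pipe = pipeline('OpenGVLab/InternVL2-8B',
                    backend_config=PytorchEngineConfig(enable_prefix_caching=True))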
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -244,6 +244,16 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         """
         self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)
 
+    def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None:
+        """Copy cache.
+
+        Args:
+            src_to_dst (Dict[int, int]): Map between src and dst.
+            cache_type (str): Cache type, 'cpu' or 'gpu'.
+        """
+        target_cache = self.full_gpu_cache if cache_type == 'gpu' else self.full_cpu_cache
+        self._swap(target_cache, target_cache, src_to_dst)
+
     @classmethod
     def get_cache_block_size(cls,
                              block_size: int,
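copy_to reuses _swap with the same cache list as both source and destination, so a block copy is an intra-cache, swap-style device memcpy. A standalone sketch of the semantics on toy tensors (not the real cache layout):

    from typing import Dict

    import torch

    def copy_blocks(cache: torch.Tensor, src_to_dst: Dict[int, int]) -> None:
        # Copy whole blocks inside one cache tensor, as copy_to does for
        # each layer's key/value cache.
        for src, dst in src_to_dst.items():
            cache[dst].copy_(cache[src])

    cache = torch.arange(8.).reshape(4, 2)  # 4 blocks of size 2
    copy_blocks(cache, {0: 3})              # block 3 becomes a copy of block 0
    assert torch.equal(cache[3], cache[0])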
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -798,6 +798,7 @@ def __make_dummy_inputs():
             inputs=ModelInputs.make_dummy(1, is_decoding=not prefill),
             swap_in_map=dict(),
             swap_out_map=dict(),
+            copy_map=dict(),
             loop_count=num_loops,
             is_dummy=True,
             sync_long_context=False,
@@ -826,6 +827,7 @@ def __make_dummy_inputs():
         running = scheduler_output.running
         swap_in_map = scheduler_output.swap_in_map
         swap_out_map = scheduler_output.swap_out_map
+        copy_map = scheduler_output.copy_map
 
         if self.should_execute_dummy_batch and len(running) == 0:
             return __make_dummy_inputs()
@@ -847,6 +849,7 @@ def __make_dummy_inputs():
             inputs=inputs,
             swap_in_map=swap_in_map,
             swap_out_map=swap_out_map,
+            copy_map=copy_map,
             loop_count=num_loops,
             all_ids=all_ids,
             guided_input_ids=guided_input_ids,
29 changes: 19 additions & 10 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
     return f'rank[{rank}] - {msg}'
 
 
-def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
+def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict, copy_map: dict):
     """perform cache swapping."""
     issued_cache_op = False
     if len(swap_in_map) > 0:
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
     if len(swap_out_map) > 0:
         cache_engine.swap_out(swap_out_map)
         issued_cache_op = True
-
+    if len(copy_map) > 0:
+        cache_engine.copy_to(copy_map)
+        issued_cache_op = True
     if issued_cache_op:
         cache_engine.events.wait()
 
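cache_swapping now issues swap-ins, swap-outs, and intra-GPU copies in one place, then waits once on the cache engine's events. A sketch of a call site (block ids are made up; every map is {src_block: dst_block}):

    cache_swapping(cache_engine,
                   swap_in_map={0: 4},   # host block 0 -> device block 4
                   swap_out_map={7: 2},  # device block 7 -> host block 2
                   copy_map={4: 5})      # duplicate device block 4 into block 5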
@@ -135,7 +137,7 @@ def all_context(self):
     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
         raise NotImplementedError('NotImplemented.')
 
-    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
         """model forward.
 
         Args:
@@ -200,6 +202,7 @@ async def _async_model_forward(
         inputs: ModelInputs,
         swap_in_map: Dict,
         swap_out_map: Dict,
+        copy_map: Dict,
         return_logits: bool,
         sync_long_context: bool,
     ):
@@ -241,12 +244,15 @@ def get_output(self):
 
         async def __forward(inputs):
             """forward."""
-            nonlocal swap_done, swap_in_map, swap_out_map
+            nonlocal swap_done, swap_in_map, swap_out_map, copy_map
             if swap_done:
-                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict(), copy_map=dict())
             else:
                 swap_done = True
-                return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+                return await self.async_forward(inputs,
+                                                swap_in_map=swap_in_map,
+                                                swap_out_map=swap_out_map,
+                                                copy_map=copy_map)
 
         async def __long_context_single_forward(new_inputs, max_seqlen: int):
             """one large sequence."""
@@ -334,6 +340,7 @@ async def _async_step_background(
         inputs: ModelInputs,
         swap_in_map: Dict,
         swap_out_map: Dict,
+        copy_map: Dict,
         loop_count: int,
         all_ids: torch.Tensor = None,
         guided_input_ids: torch.Tensor = None,
@@ -420,6 +427,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
             inputs,
             swap_in_map=swap_in_map,
             swap_out_map=swap_out_map,
+            copy_map=copy_map,
             return_logits=return_logits,
             sync_long_context=sync_long_context,
         )
@@ -467,6 +475,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
             if is_decoding and idx < loop_count - 1:
                 swap_in_map = dict()
                 swap_out_map = dict()
+                copy_map = dict()
                 inputs.model_metas = model_metas
                 __update_inputs(next_token_ids)
 
@@ -637,8 +646,8 @@ def build_cache_engine(self):
 
         self.cache_engine = CacheEngine(self.cache_config, self.model_config, world_size=tp)
 
-    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
-        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
+        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
         output = model_forward(
             self.patched_model,
             inputs,
)
return output

async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
"""model forward.

Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
await asyncio.sleep(0)
return output

Expand Down
86 changes: 76 additions & 10 deletions lmdeploy/pytorch/messages.py
@@ -371,21 +371,67 @@ def __init__(self, multimodals: MultiModalInputs):
         if multimodals is None:
             multimodals = dict()
         self.multimodals = multimodals
+        self._init_mm_ranges()
+
+    def _init_mm_ranges(self):
+        """init mm ranges and sort them."""
+        mm_ranges = []
+        for _, modal_datas in self.multimodals.items():
+            for modal_data in modal_datas:
+                data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
+                mm_ranges.append(data)
+        mm_ranges.sort(key=lambda x: x[1])
+        self._mm_ranges = mm_ranges
+
+    @property
+    def mm_ranges(self):
+        """mm_ranges."""
+        return self._mm_ranges
 
     def get_datas(self, start=0, end=-1):
         """get multimodals from prompt positions [start, end)."""
         outs = dict()
-        test_range = range(start, end)
         for modal_type, modal_datas in self.multimodals.items():
             data = []
             for modal_data in modal_datas:
-                if (modal_data.start not in test_range and modal_data.end not in test_range):
-                    continue
-                data.append(modal_data)
+                if modal_data.start < end and modal_data.end > start:
+                    data.append(modal_data)
             if len(data) > 0:
                 outs[modal_type] = data
         return outs

+    def get_step(self, step: int) -> int:
+        """get the step pulled back so it does not fall inside an image."""
+        real_step = step
+        for start, end, _ in self._mm_ranges:
+            if start <= real_step < end:
+                real_step = start
+        return real_step
+
+    def has_data(self, start: int, end: int) -> bool:
+        """whether there is multimodal data in [start, end)."""
+        return any([s < end and e > start for s, e, _ in self._mm_ranges])
+
+    def get_hash_values(self, start: int, end: int):
+        """get multimodal hash values in [start, end)."""
+        mm_hash_values = []
+        multimodal_ends = []
+
+        for mm_start, mm_end, hash_value in self._mm_ranges:
+            # the mm range intersects the target range
+            if mm_start < end and mm_end > start:
+                mm_hash_values.append(hash_value)
+            # the mm end falls in the target range
+            if start < mm_end <= end:
+                cur_data = (tuple(mm_hash_values), mm_end)
+                multimodal_ends.append(cur_data)
+
+        if len(mm_hash_values) == 0:
+            mm_hash_values = None
+        else:
+            mm_hash_values = tuple(mm_hash_values)
+        return mm_hash_values, multimodal_ends

     def add_inputs(self, input_mms: MultiModalInputs):
         """add new inputs."""
         for modal_type, vals in input_mms.items():
@@ -394,9 +440,17 @@ def add_inputs(self, input_mms: MultiModalInputs):
             else:
                 self.multimodals[modal_type] = vals
 
-    def empty(self):
+            # update mm_ranges
+            for modal_data in vals:
+                data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
+                self._mm_ranges.append(data)
+
+        # sort mm_ranges
+        self._mm_ranges.sort(key=lambda x: x[1])
+
+    def empty(self) -> bool:
         if len(self.multimodals) == 0:
-            return 0
+            return True
 
         return all(len(vals) == 0 for vals in self.multimodals)
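The (start, end, hash_value) triples make the interval logic easy to check in isolation. A standalone sketch mirroring has_data and get_hash_values on toy ranges (not lmdeploy code):

    # Two images occupying token positions [5, 15) and [20, 30).
    mm_ranges = [(5, 15, 'img0'), (20, 30, 'img1')]

    def has_data(start, end):
        # Any range intersects [start, end).
        return any(s < end and e > start for s, e, _ in mm_ranges)

    def get_hash_values(start, end):
        # Accumulate hashes; record the running hash tuple at every
        # image end that falls inside (start, end].
        hashes, ends = [], []
        for s, e, h in mm_ranges:
            if s < end and e > start:
                hashes.append(h)
            if start < e <= end:
                ends.append((tuple(hashes), e))
        return (tuple(hashes) if hashes else None), ends

    assert has_data(0, 6) and not has_data(15, 20)
    assert get_hash_values(0, 32) == (('img0', 'img1'),
                                      [(('img0',), 15), (('img0', 'img1'), 30)])

The (hash_tuple, end) pairs look like stable keys for matching prefix-cache blocks at image boundaries.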

@@ -582,7 +636,7 @@ def update_token_ids(self,
 
         # update multimodals
         if multimodals is not None:
-            multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids)
+            multimodals = HistoryMultiModals.update_multimodals(multimodals, self._num_history_ids)
             self.history_multimodals.add_inputs(multimodals)
 
         # cross
@@ -610,11 +664,11 @@ def set_step(self, step: int):
         """set step."""
         num_all_ids = self.num_all_ids
         # update step for vlm
-        if len(self.history_embeddings) > 0:
-            new_step, self._num_history_images, self._num_images = \
-                self.history_embeddings.get_step(step)
+        if self.history_multimodals is not None:
+            new_step = self.history_multimodals.get_step(step)
+            assert 0 <= new_step <= step
             step = new_step
 
         self._num_history_ids = step
         self._num_token_ids = num_all_ids - step
         self.num_ignored_history = min(step, self.num_ignored_history)
@@ -625,3 +679,15 @@ def set_step(self, step: int):
         if self.history_multimodals is not None:
             self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids)
             self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)
+
+    def __repr__(self):
+        return (f'SchedulerSequence(seq_id={self.seq_id}, session_id={self.session_id}, '
+                f'status={self.status}, arrive_time={self.arrive_time}, '
+                f'return_logits={self.return_logits}, sampling_param={self.sampling_param}, '
+                f'num_history_tokens={self.history_len}, num_all_tokens={self.num_all_ids}, '
+                f'num_new_tokens={self.num_new_tokens}, all_token_ids={self.all_ids}, '
+                f'mm_ranges={self.history_multimodals.mm_ranges}, '
+                f'num_gpu_blocks={self.num_blocks}, gpu_blocks={self.logical_blocks.get_real_blocks()}, '
+                f'last_shared_node={getattr(self.logical_blocks, "last_shared_node", None)})')
+
+    __str__ = __repr__
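get_step pulls a resume step that lands inside an image back to the image start, which makes sense if vision features are only injected for whole images; set_step then asserts the adjusted step never moves forward. Continuing the toy sketch above:

    def get_step(step):
        # Mirror of HistoryMultiModals.get_step.
        real_step = step
        for s, e, _ in mm_ranges:
            if s <= real_step < e:
                real_step = s
        return real_step

    assert get_step(8) == 5    # inside [5, 15): snap back to the image start
    assert get_step(17) == 17  # between images: unchanged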
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/chatglm2.py
@@ -852,13 +852,14 @@ def preprocess_input(self,
             offset = input_mm['offset']
             num_pad = input_mm['image_tokens']
             image_token_id = input_mm['image_token_id']
+            hash_value = input_mm.get('hash_value', None)
             if isinstance(num_pad, torch.Tensor):
                 num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
                                        end=offset + num_pad,
-                                       meta=dict(image_token_id=image_token_id))
+                                       meta=dict(image_token_id=image_token_id, hash_value=hash_value))
             input_imgs.append(mm_data)
 
         result = PreprocessInputResult(
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/cogvlm.py
@@ -889,13 +889,14 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg
             offset = input_mm['offset']
             image_token_id = input_mm['image_token_id']
             num_pad = input_mm['image_tokens']
+            hash_value = input_mm.get('hash_value', None)
             if isinstance(num_pad, torch.Tensor):
                 num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
                                        end=offset + num_pad,
-                                       meta=dict(image_token_id=image_token_id))
+                                       meta=dict(image_token_id=image_token_id, hash_value=hash_value))
             input_imgs.append(mm_data)
 
         result = PreprocessInputResult(
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/deepseek_vl2.py
@@ -109,6 +109,7 @@ def __init__(self,
                  dtype: torch.dtype = None,
                  device: torch.device = None):
         super().__init__()
+
         self.ctx_mgr = ctx_mgr
 
         # ----------- vision encoder ------------
@@ -144,7 +145,7 @@ def __init__(self,
         # ----------- language model ------------
         language_config = config.language_config
         self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
-
+        self.config = language_config
         # ----------- input processor ------------
         self.input_processor = DeepSeekVLV2InputProcessor(config, dtype)
 
@@ -434,6 +435,7 @@ def preprocess_input(self,
             offset = input_mm['offset']
             image_token_id = input_mm['image_token_id']
             num_pad = input_mm['image_tokens']
+            hash_value = input_mm.get('hash_value', None)
             images_spatial_crop = input_mm.get('images_spatial_crop', None)
             if isinstance(num_pad, torch.Tensor):
                 num_pad = num_pad.item()
@@ -443,6 +445,7 @@
                 end=offset + num_pad,
                 meta=dict(
                     image_token_id=image_token_id,
+                    hash_value=hash_value,
                     images_spatial_crop=images_spatial_crop,
                 ))
 
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/gemma3_vl.py
@@ -105,13 +105,14 @@ def preprocess_input(self,
             offset = input_mm['offset']
             image_token_id = input_mm['image_token_id']
             num_pad = input_mm['image_tokens']
+            hash_value = input_mm.get('hash_value', None)
             if isinstance(num_pad, torch.Tensor):
                 num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
                                        end=offset + num_pad,
-                                       meta=dict(image_token_id=image_token_id))
+                                       meta=dict(image_token_id=image_token_id, hash_value=hash_value))
             input_imgs.append(mm_data)
 
         result = PreprocessInputResult(
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/internvl.py
@@ -519,13 +519,14 @@ def preprocess_input(self,
             offset = input_mm['offset']
             image_token_id = input_mm['image_token_id']
             num_pad = input_mm['image_tokens']
+            hash_value = input_mm.get('hash_value', None)
             if isinstance(num_pad, torch.Tensor):
                 num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
                                        end=offset + num_pad,
-                                       meta=dict(image_token_id=image_token_id))
+                                       meta=dict(image_token_id=image_token_id, hash_value=hash_value))
             input_imgs.append(mm_data)
 
         result = PreprocessInputResult(
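Every preprocessor change above follows the same shape: read an optional hash_value from the processor output and thread it into MultiModalTensor.meta, where HistoryMultiModals._init_mm_ranges picks it up. The hash itself is produced upstream; one plausible way to derive such a key, shown only as an assumption and not the hashing lmdeploy actually uses:

    import hashlib

    import torch

    def image_hash(pixel_values: torch.Tensor) -> str:
        # Stable content key for one image tensor (illustrative only).
        return hashlib.sha256(pixel_values.cpu().numpy().tobytes()).hexdigest()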