287 changes: 270 additions & 17 deletions tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Large diffs are not rendered by default.

41 changes: 31 additions & 10 deletions tensorrt_llm/commands/serve.py
@@ -1,5 +1,6 @@
import asyncio
import gc
import json
import os
import signal # Added import
import subprocess # nosec B404
@@ -18,6 +19,7 @@
from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, KvCacheConfig,
SchedulerConfig)
@@ -138,12 +140,14 @@ def get_llm_args(model: str,
return llm_args, llm_args_extra_dict


def launch_server(host: str,
port: int,
llm_args: dict,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None):
def launch_server(
host: str,
port: int,
llm_args: dict,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
multimodal_server_config: Optional[MultimodalServerConfig] = None):

backend = llm_args["backend"]
model = llm_args["model"]
@@ -165,7 +169,8 @@ def launch_server(host: str,
model=model,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg,
disagg_cluster_config=disagg_cluster_config)
disagg_cluster_config=disagg_cluster_config,
multimodal_server_config=multimodal_server_config)

# Optionally disable GC (default: not disabled)
if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -325,6 +330,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
is_flag=True,
default=False,
help="Enable chunked prefill")
@click.option("--media-io-kwargs",
type=str,
default=None,
help="Keyword arguments for media I/O.")
def serve(
model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
@@ -335,7 +344,9 @@
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str]):
enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str],
media_io_kwargs: Optional[str]):
"""Running an OpenAI API compatible server

MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -391,8 +402,18 @@
except ValueError:
raise ValueError(f"Invalid server role: {server_role}. " \
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
launch_server(host, port, llm_args, metadata_server_cfg, server_role,
disagg_cluster_config)

# Parse media_io_kwargs from JSON string to dict if provided
parsed_media_io_kwargs = None
if media_io_kwargs is not None:
try:
parsed_media_io_kwargs = json.loads(media_io_kwargs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")

multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
launch_server(host, port, llm_args, metadata_server_cfg, server_role, disagg_cluster_config, multimodal_server_config)


@click.command("mm_embedding_serve")
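
The new `--media-io-kwargs` flag takes a JSON string that `serve()` parses and wraps in a `MultimodalServerConfig` before calling `launch_server`. A minimal sketch of that flow follows; the example payload keys (`"video"`, `"num_frames"`) are hypothetical placeholders for illustration, not options defined by this diff:

```python
import json
from typing import Optional

from tensorrt_llm.inputs.multimodal import MultimodalServerConfig


def build_multimodal_server_config(
        media_io_kwargs: Optional[str]) -> MultimodalServerConfig:
    """Mirror of the parsing done in serve(): JSON string -> dict -> config."""
    parsed = None
    if media_io_kwargs is not None:
        try:
            parsed = json.loads(media_io_kwargs)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
    return MultimodalServerConfig(media_io_kwargs=parsed)


# On the command line this would look roughly like:
#   trtllm-serve <model> --media-io-kwargs '{"video": {"num_frames": 32}}'
# where the "video"/"num_frames" keys are illustrative only.
config = build_multimodal_server_config('{"video": {"num_frames": 32}}')
print(config.media_io_kwargs)  # {'video': {'num_frames': 32}}
```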
3 changes: 3 additions & 0 deletions tensorrt_llm/inputs/__init__.py
@@ -1,4 +1,5 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .evs import compute_retained_tokens_count, compute_retention_mask
from .multimodal import MultimodalInput
from .registry import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor,
ExtraProcessedInputs, InputProcessor,
@@ -50,4 +51,6 @@
"load_image",
"load_video",
"get_cache_salt_id",
"compute_retained_tokens_count",
"compute_retention_mask",
]
93 changes: 93 additions & 0 deletions tensorrt_llm/inputs/evs.py
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import torch


def compute_retained_tokens_count(video_size: torch.LongTensor,
spatial_merge_size: int,
pruning_ratio: float) -> int:
"""
Compute the number of retained tokens for a given video.
The method ensures that all tokens from the first frame are retained
regardless of the pruning ratio.

Args:
video_size: The size of the video in the format of (T, H, W).
spatial_merge_size: The size of the spatial merge.
pruning_ratio: The pruning ratio.

Returns:
The number of retained tokens.
"""
# Note on why map(int, ...) exists here.
# In vLLM, a rounding issue was observed when the input was a Tensor versus a tuple of integers.
# The tuple-of-ints input came from the preprocessing stage, while in the actual forward() it was a Tensor.
# To make sure the number of output tokens stays the same, an explicit cast was added.
T, H, W = map(int, video_size)
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
evs_num_tokens = int(T * min_num_tokens * (1 - pruning_ratio))
return max(min_num_tokens, evs_num_tokens)


def compute_retention_mask(
video_embeds: torch.Tensor,
video_size: torch.LongTensor,
spatial_merge_size: int,
pruning_ratio: float,
flatten_output: bool = True,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.

Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
or shape `(T, H * W // spatial_merge_size ^ 2, hidden_size)`.
video_size (`torch.LongTensor` of shape `(3,)`):
The temporal (T), height (H) and width (W) of the video.
spatial_merge_size: Size reduction applied to the row and column dimensions.
pruning_ratio (`float`): Pruning ratio in [0, 1).
flatten_output (`bool`): Whether to flatten the output mask.

Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = video_size

# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T,
H // spatial_merge_size,
W // spatial_merge_size,
video_embeds.size(-1),
)

# Core EVS
similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
video_embeds[:-1, ...],
dim=-1)
dissimilarity = 1 - similarity

# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
dim=0)

dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat,
dim=-1,
descending=True,
stable=True)
retain_num_tokens = compute_retained_tokens_count(video_size,
spatial_merge_size,
pruning_ratio)
topk_indices = order[:retain_num_tokens]

retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())

mask = retention_mask.view(-1) if flatten_output else retention_mask
return mask
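
A self-contained sketch of how the new EVS helpers might be exercised; the tensor shapes, merge size, and pruning ratio below are illustrative assumptions, not values taken from this PR:

```python
import torch

from tensorrt_llm.inputs import (compute_retained_tokens_count,
                                 compute_retention_mask)

# Dummy video: T=4 frames on a 28x28 patch grid, merged 2x2, so each frame
# contributes (28 // 2) * (28 // 2) = 196 merged tokens.
T, H, W = 4, 28, 28
spatial_merge_size = 2
hidden_size = 64
pruning_ratio = 0.5

video_size = torch.tensor([T, H, W], dtype=torch.long)
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
video_embeds = torch.randn(T * tokens_per_frame, hidden_size)

num_retained = compute_retained_tokens_count(video_size, spatial_merge_size,
                                             pruning_ratio)
mask = compute_retention_mask(video_embeds, video_size, spatial_merge_size,
                              pruning_ratio)

# The flattened boolean mask selects exactly num_retained embeddings:
# the full first frame plus the most dissimilar tokens from later frames.
assert mask.sum().item() == num_retained
pruned_embeds = video_embeds[mask]
print(num_retained, pruned_embeds.shape)
```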
33 changes: 30 additions & 3 deletions tensorrt_llm/inputs/multimodal.py
@@ -520,6 +520,11 @@ def has_content(self) -> bool:
return bool(self.multimodal_input or self.multimodal_data)


@dataclass
class MultimodalServerConfig():
media_io_kwargs: Optional[dict] = None


# adopt from vllm : https://github.com/vllm-project/vllm/blob/main/vllm/vllm/multimodal/hash.py
def serialize_item(obj: object) -> bytes:
# Simple cases
@@ -558,6 +563,13 @@ def _hash_image(image):
if isinstance(frame, torch.Tensor):
frame = frame.detach().cpu().contiguous()
hasher.update(serialize_item(frame))
elif isinstance(image, dict):
frames = image["frames"]
for frame in frames:
hasher.update(b"<frame>")
if isinstance(frame, torch.Tensor):
frame = frame.detach().cpu().contiguous()
hasher.update(serialize_item(frame))
else:
hasher.update(serialize_item(image))

@@ -622,6 +634,8 @@ def find_mm_token_lengths(mm_data: Dict[str, Any],
image=item, )
modality_token_lengths.append(num_tokens)
elif modality == "video":
if isinstance(item, dict):
item = item["frames"]
assert isinstance(item, list), "Video must be a list of frames"
if isinstance(item[0], torch.Tensor):
item = [ToPILImage()(frame) for frame in item]
@@ -658,6 +672,7 @@ def find_mm_token_positions(
num_mm_tokens: List of contiguous chunk lengths for each multimodal item
vocab_size: Size of the model's vocabulary (used to identify tokens > vocab_size)
mm_token_ids: Specific token IDs that represent multimodal tokens
mm_special_token_ids: Specific token IDs that represent special multimodal tokens

Returns:
List of starting positions for each contiguous multimodal token
@@ -679,13 +694,25 @@
elif isinstance(input_ids, np.ndarray):
input_ids = torch.from_numpy(input_ids)

# Create mask for multimodal tokens
# Create mask for multimodal tokens including special tokens if provided
if mm_token_ids is None:
mm_mask = input_ids >= vocab_size
if mm_special_token_ids is not None:
mm_special_token_ids = mm_special_token_ids.to(
device=input_ids.device, dtype=input_ids.dtype)
mm_mask = mm_mask | torch.isin(input_ids, mm_special_token_ids)
else:
mm_token_ids = mm_token_ids.to(device=input_ids.device,
dtype=input_ids.dtype)
if mm_token_ids.ndim != 1:
raise ValueError("mm_token_ids must be a 1D tensor")
mm_token_ids = torch.unique(mm_token_ids)
if mm_special_token_ids is not None:
mm_special_token_ids = mm_special_token_ids.to(
device=input_ids.device, dtype=input_ids.dtype)
mm_token_ids = torch.unique(
torch.cat([mm_token_ids, mm_special_token_ids]))
else:
mm_token_ids = torch.unique(mm_token_ids)
mm_mask = torch.isin(input_ids, mm_token_ids)

# If no multimodal tokens found, return empty list
@@ -696,7 +723,7 @@
mm_positions = torch.where(mm_mask)[0].tolist()
assert len(mm_positions) == sum(
num_mm_tokens
), f"Number of multimodal tokens does not match sum of all lengths"
), f"Number of multimodal tokens does not match sum of all lengths {len(mm_positions)} != {sum(num_mm_tokens)}"

# Use num_mm_tokens to find the starting position of each chunk
start_positions = []
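
To illustrate the `mm_special_token_ids` change, here is a minimal standalone sketch of the updated mask construction. The token IDs and the input sequence are made up for illustration; it mirrors the `torch.cat`/`torch.unique`/`torch.isin` logic added above rather than calling `find_mm_token_positions` directly:

```python
import torch

# Illustrative values only: a fake token sequence where IDs 32001/32002 are
# regular multimodal tokens and 32100 is a special multimodal token.
input_ids = torch.tensor([1, 5, 32100, 32001, 32001, 32002, 7, 32100, 32001, 2])
mm_token_ids = torch.tensor([32001, 32002])
mm_special_token_ids = torch.tensor([32100])

# Special token IDs are merged into the set of multimodal token IDs before
# building the membership mask, so special tokens count toward mm positions.
mm_token_ids = torch.unique(torch.cat([mm_token_ids, mm_special_token_ids]))
mm_mask = torch.isin(input_ids, mm_token_ids)

positions = torch.where(mm_mask)[0].tolist()
print(positions)  # [2, 3, 4, 5, 7, 8]
```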