10 changes: 10 additions & 0 deletions cpp/kernels/fmha_v2/setup.py
@@ -6379,6 +6379,16 @@ def enumerate_kernels():
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == False)
# CLIP/SigLIP support.
or (kspec.sm == 100
and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
and kspec.head_size == 80
and kspec.head_size_v == 0
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
# Deepseek MLA (generation 576/512 paged)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
5 changes: 4 additions & 1 deletion cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -46,7 +46,10 @@ QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)

FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
: mFixedParams(fixedParams)
, mUseTllmGen(tensorrt_llm::common::isSM100Family())
// TRTLLM-GEN only supports power-of-2 head sizes.
// Non-power-of-2 head sizes fall back to fmha v2.
// Please update fmha_v2/setup.py if you want to add more supported head sizes.
, mUseTllmGen(tensorrt_llm::common::isSM100Family() && (fixedParams.headSize & (fixedParams.headSize - 1)) == 0)
{
if (mUseTllmGen)
{
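As a quick illustration of the dispatch rule described in the new comment (not part of the diff), the power-of-two test can be sketched in Python; head size 80 (CLIP/SigLIP) is one of the sizes that falls back to fmha v2 on SM100-family GPUs:

def is_power_of_two(head_size: int) -> bool:
    # A power of two has exactly one set bit, so x & (x - 1) clears it to zero.
    return head_size > 0 and (head_size & (head_size - 1)) == 0

# Mirrors mUseTllmGen: TRTLLM-GEN handles power-of-2 head sizes on SM100-family GPUs;
# other head sizes (e.g. 80 for CLIP/SigLIP) fall back to the fmha v2 kernels added in setup.py.
for head_size in (64, 80, 128, 576):
    print(head_size, "TRTLLM-GEN" if is_power_of_two(head_size) else "fmha v2")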
8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/model_config.py
@@ -161,13 +161,15 @@ def __setattr__(self, key, value):
"""
Prevent modification of frozen instance attributes.
However, we allow modification of 'extra_attrs' attributes for torch.compile
and 'pretrained_config' attributes for mutimodal models. All the other
attributes are frozen.
and 'pretrained_config' attributes for multimodal models.
'quant_config' is allowed to be modified to set a different quantization config for VLMs.
All the other attributes are frozen.
This can be bypassed by manually setting '_frozen' to False. The design is
to discourage modifying the attributes unintentionally.
"""
if self._frozen:
if key not in ('_frozen', 'extra_attrs', 'pretrained_config'):
if key not in ('_frozen', 'extra_attrs', 'pretrained_config',
'quant_config'):
raise AttributeError(
f"Cannot modify ModelConfig.'{key}' - instance is frozen")
super().__setattr__(key, value)
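A minimal usage sketch of the relaxed freeze (illustrative only; `model_config` is assumed to be a frozen ModelConfig instance and the right-hand sides are placeholder objects):

model_config.quant_config = new_quant_config              # now allowed, e.g. to override quantization for a VLM
model_config.pretrained_config = updated_pretrained_config  # already allowed for multimodal models
try:
    model_config.mapping = None                            # any other attribute is still frozen
except AttributeError as err:
    print(err)  # Cannot modify ModelConfig.'mapping' - instance is frozen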
244 changes: 214 additions & 30 deletions tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions tensorrt_llm/_torch/models/modeling_radio.py
@@ -2,6 +2,7 @@
# Note: The code is to extract image embedding from RADIO model, to support Nano v2 VLM.
# TODO: Check and add more compatible logic for the full-series RADIO model.

import copy
import math
from collections import namedtuple
from typing import (Dict, Iterable, List, Literal, NamedTuple, Optional, Tuple,
@@ -21,6 +22,7 @@
from tensorrt_llm._torch.models import modeling_utils
from tensorrt_llm._torch.modules import attention as trtllm_attention
from tensorrt_llm._torch.modules import mlp as trtllm_mlp
from tensorrt_llm.models.modeling_utils import QuantConfig

InputDimT = Union[int, Tuple[int, int]]

@@ -770,14 +772,26 @@ def _extract_final(self,
class RADIOVisionModel(PreTrainedModel):
"""Modify from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py."""

def __init__(self, model_config: model_config_lib.ModelConfig):
def __init__(self,
model_config: model_config_lib.ModelConfig,
disable_quantization: bool = True):
"""
Args:
model_config: Model configuration.
disable_quantization: Disable quantization for the RADIO model.
Since the RADIO model is vision-only, quantization is disabled for it by default.
"""
config = model_config.pretrained_config
super().__init__(config)
self.model_config = model_config

self.model_config = copy.deepcopy(model_config)
if self.model_config.quant_config is not None:
if disable_quantization:
# The basic method `apply_quant_config_exclude_modules` in DecoderModelForCausalLM keeps the kv_cache_quant_algo, so we also keep it here.
self.model_config.quant_config = QuantConfig(
kv_cache_quant_algo=self.model_config.quant_config.
kv_cache_quant_algo)

self.config = config

RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
@@ -809,7 +823,7 @@ def __init__(self, model_config: model_config_lib.ModelConfig):
mlp_ratio=mlp_ratio,
drop_rate=args.drop,
special_args=args,
model_config=model_config,
model_config=self.model_config,
)
if hasattr(vit_model,
'norm') and not getattr(args, 'model_norm', False):
@@ -848,7 +862,7 @@ def __init__(self, model_config: model_config_lib.ModelConfig):
adaptors=adaptors,
feature_normalizer=feature_normalizer,
inter_feature_normalizer=inter_feature_normalizer,
model_config=model_config,
model_config=self.model_config,
)

def load_weights(self, weights):
@@ -861,8 +875,10 @@ def load_weights(self, weights):
filter_weights, strict=False)
# Check missing and unexpected keys.
# The input conditioner is not initialized in current implementation.
unexpected_keys.remove("input_conditioner.norm_mean")
unexpected_keys.remove("input_conditioner.norm_std")
if "input_conditioner.norm_mean" in unexpected_keys:
unexpected_keys.remove("input_conditioner.norm_mean")
if "input_conditioner.norm_std" in unexpected_keys:
unexpected_keys.remove("input_conditioner.norm_std")
# Partial model.blocks weights will be loaded in the following step.
for m in missing_keys:
if not m.startswith('model.blocks.'):
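The quantization handling above boils down to a small pattern; a rough standalone sketch (the helper name is hypothetical, and a ModelConfig-like object with a `quant_config` attribute is assumed):

import copy

from tensorrt_llm.models.modeling_utils import QuantConfig

def vision_config_without_quantization(model_config):
    # Deep-copy so the LLM part of the model keeps its original quant_config,
    # then keep only the kv-cache algorithm for the vision encoder.
    config_copy = copy.deepcopy(model_config)
    if config_copy.quant_config is not None:
        config_copy.quant_config = QuantConfig(
            kv_cache_quant_algo=config_copy.quant_config.kv_cache_quant_algo)
    return config_copy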
3 changes: 3 additions & 0 deletions tensorrt_llm/inputs/__init__.py
@@ -1,4 +1,5 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .evs import compute_retained_tokens_count, compute_retention_mask
from .multimodal import MultimodalInput
from .registry import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
InputProcessor, MultimodalPlaceholderMetadata,
@@ -48,4 +49,6 @@
"load_image",
"load_video",
"get_cache_salt_id",
"compute_retained_tokens_count",
"compute_retention_mask",
]
93 changes: 93 additions & 0 deletions tensorrt_llm/inputs/evs.py
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import torch


def compute_retained_tokens_count(video_size: torch.LongTensor,
spatial_merge_size: int,
pruning_ratio: float) -> int:
"""
Compute the number of retained tokens for a given video.
The method ensures that all tokens from the first frame are retained
regardless of the pruning ratio.

Args:
video_size: The size of the video in the format of (T, H, W).
spatial_merge_size: The size of the spatial merge.
pruning_ratio: The pruning ratio.

Returns:
The number of retained tokens.
"""
# Note on why map(int, ...) is used here.
# In vLLM, a rounding difference was observed when the input was a Tensor versus a tuple of integers.
# The tuple of ints came from the preprocessing stage, while in the actual forward() the input was a Tensor.
# The explicit cast keeps the number of output tokens the same in both cases.
T, H, W = map(int, video_size)
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
evs_num_tokens = int(T * min_num_tokens * (1 - pruning_ratio))
return max(min_num_tokens, evs_num_tokens)


def compute_retention_mask(
video_embeds: torch.Tensor,
video_size: torch.LongTensor,
spatial_merge_size: int,
pruning_ratio: float,
flatten_output: bool = True,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.

Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
or shape `(T, H * W // spatial_merge_size ^ 2, hidden_size)`.
video_size (`torch.LongTensor` of shape `(3)`):
The temporal, height and width dimensions of the video.
spatial_merge_size: Size reduction for rows & cols dimensions.
pruning_ratio (`float`): Pruning ratio in [0, 1).
flatten_output (`bool`): Whether to flatten the output mask.

Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = video_size

# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T,
H // spatial_merge_size,
W // spatial_merge_size,
video_embeds.size(-1),
)

# Core EVS
similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
video_embeds[:-1, ...],
dim=-1)
dissimilarity = 1 - similarity

# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
dim=0)

dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat,
dim=-1,
descending=True,
stable=True)
retain_num_tokens = compute_retained_tokens_count(video_size,
spatial_merge_size,
pruning_ratio)
topk_indices = order[:retain_num_tokens]

retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())

mask = retention_mask.view(-1) if flatten_output else retention_mask
return mask
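A usage sketch of the two helpers above on a toy video (shapes and values are illustrative, not from the PR's tests):

import torch

from tensorrt_llm.inputs import (compute_retained_tokens_count,
                                 compute_retention_mask)

# Toy video: 4 frames of 28x28 patches, spatial merge size 2 -> 14 * 14 = 196 tokens per frame.
video_size = torch.tensor([4, 28, 28])
spatial_merge_size, pruning_ratio = 2, 0.75

# int(4 * 196 * 0.25) = 196 and max(196, 196) = 196, so exactly one frame's worth of tokens survives.
num_retained = compute_retained_tokens_count(video_size, spatial_merge_size, pruning_ratio)

video_embeds = torch.randn(4 * 196, 64)  # (T * H/m * W/m, hidden_size)
mask = compute_retention_mask(video_embeds, video_size, spatial_merge_size, pruning_ratio)
assert video_embeds[mask].shape[0] == num_retained  # the first frame is always fully kept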
28 changes: 24 additions & 4 deletions tensorrt_llm/inputs/multimodal.py
@@ -642,7 +642,8 @@ def find_mm_token_positions(
num_mm_tokens: List[int],
vocab_size: Optional[int] = None,
mm_token_ids: Optional[torch.Tensor] = None,
mm_special_token_ids: Optional[torch.Tensor] = None
mm_special_token_ids: Optional[torch.Tensor] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[int], List[int]]:
"""Get starting positions of contiguous multimodal token chunks using known lengths.

@@ -658,6 +659,7 @@
num_mm_tokens: List of contiguous chunk lengths for each multimodal item
vocab_size: Size of the model's vocabulary (used to identify tokens > vocab_size)
mm_token_ids: Specific token IDs that represent multimodal tokens
mm_special_token_ids: Specific token IDs that represent special multimodal tokens

Returns:
List of starting positions for each contiguous multimodal token
@@ -679,13 +681,25 @@
elif isinstance(input_ids, np.ndarray):
input_ids = torch.from_numpy(input_ids)

# Create mask for multimodal tokens
# Create mask for multimodal tokens including special tokens if provided
if mm_token_ids is None:
mm_mask = input_ids >= vocab_size
if mm_special_token_ids is not None:
mm_special_token_ids = mm_special_token_ids.to(
device=input_ids.device, dtype=input_ids.dtype)
mm_mask = mm_mask | torch.isin(input_ids, mm_special_token_ids)
else:
mm_token_ids = mm_token_ids.to(device=input_ids.device,
dtype=input_ids.dtype)
if mm_token_ids.ndim != 1:
raise ValueError("mm_token_ids must be a 1D tensor")
mm_token_ids = torch.unique(mm_token_ids)
if mm_special_token_ids is not None:
mm_special_token_ids = mm_special_token_ids.to(
device=input_ids.device, dtype=input_ids.dtype)
mm_token_ids = torch.unique(
torch.cat([mm_token_ids, mm_special_token_ids]))
else:
mm_token_ids = torch.unique(mm_token_ids)
mm_mask = torch.isin(input_ids, mm_token_ids)

# If no multimodal tokens found, return empty list
Expand All @@ -694,9 +708,15 @@ def find_mm_token_positions(

# Get positions of all multimodal tokens
mm_positions = torch.where(mm_mask)[0].tolist()
# Collect image sizes (when provided via kwargs) to make the assertion message below more informative.
try:
images = kwargs["mm_data"]['image']
image_sizes = [image.shape for image in images]
except Exception:
image_sizes = None

assert len(mm_positions) == sum(
num_mm_tokens
), f"Number of multimodal tokens does not match sum of all lengths"
), f"Number of multimodal tokens does not match sum of all lengths: {len(mm_positions)=}, {sum(num_mm_tokens)=}, {kwargs=} PIL {image_sizes=}"

# Use num_mm_tokens to find the starting position of each chunk
start_positions = []
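A standalone sketch of the mask construction the new branch performs (token IDs are made up for illustration, not taken from any real tokenizer):

import torch

input_ids = torch.tensor([5, 9, 9, 9, 7, 5, 8, 9, 9, 5])
mm_token_ids = torch.tensor([9])          # ordinary multimodal tokens
mm_special_token_ids = torch.tensor([8])  # e.g. image begin/end markers

# Merge ordinary and special IDs before building the mask, as in the diff above.
merged_ids = torch.unique(torch.cat([mm_token_ids, mm_special_token_ids]))
mm_mask = torch.isin(input_ids, merged_ids)
# mm_mask -> [F, T, T, T, F, F, T, T, T, F]: two contiguous chunks starting at positions 1 and 6,
# which the num_mm_tokens bookkeeping below turns into start_positions.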
4 changes: 4 additions & 0 deletions tensorrt_llm/inputs/registry.py
@@ -514,6 +514,10 @@ def multimodal_hashing_process(
vocab_size=vocab_size,
mm_token_ids=mm_ids,
mm_special_token_ids=mm_special_token_ids,
kwargs={
"mm_data": mm_data,
"text_prompt": inputs["prompt"],
}
)
# Store special token offsets if available
if len(start_special_token_positions
45 changes: 41 additions & 4 deletions tensorrt_llm/inputs/utils.py
@@ -124,7 +124,7 @@ async def async_load_image(
return image


def load_video(
def _load_video_by_cv2(
video: str,
num_frames: int = 10,
format: str = "pt",
@@ -151,7 +151,7 @@ def load_video(
break
frame_count -= 1
else:
raise ValueError(f"Video '{video}' has no frames.")
raise ValueError(f"Video has no frames.")

# Extract frames uniformly
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
@@ -173,6 +173,38 @@
]


def load_base64_video(video: str) -> bytes:
parsed_url = urlparse(video)
data_spec, data = parsed_url.path.split(",", 1)
media_type, data_type = data_spec.split(";", 1)

if data_type != "base64":
msg = "Only base64 data URLs are supported for now."
raise NotImplementedError(msg)

content = base64.b64decode(data)
return content


def load_video(
video: str,
num_frames: int = 10,
format: str = "pt",
device: str = "cpu") -> Union[List[Image.Image], List[torch.Tensor]]:
parsed_url = urlparse(video)
if parsed_url.scheme in ["http", "https", ""]:
video_path = video
elif parsed_url.scheme == "data":
decoded_video = load_base64_video(video)
# TODO: is there a way to read videos from memory instead of writing to a tempfile?
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
tmp_file.write(decoded_video)
video_path = tmp_file.name
else:
raise ValueError(f"Unsupported video scheme: {parsed_url.scheme}")
return _load_video_by_cv2(video_path, num_frames, format, device)


async def async_load_video(
video: str,
num_frames: int = 10,
@@ -189,11 +221,16 @@ async def async_load_video(
suffix='.mp4') as tmp:
tmp.write(await response.content.read())
video_path = tmp.name
# TODO: add case for video encoded in base64
elif parsed_url.scheme == "data":
decoded_video = load_base64_video(video)
# TODO: is there a way to read videos from memory instead of writing to a tempfile?
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
tmp_file.write(decoded_video)
video_path = tmp_file.name
else:
video_path = video

return load_video(video_path, num_frames, format, device)
return _load_video_by_cv2(video_path, num_frames, format, device)
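A small usage sketch of the new data-URL path (the local file name is a placeholder; http(s) URLs and plain paths keep going straight to OpenCV):

import base64

from tensorrt_llm.inputs import load_video

# Hypothetical local file; any raw .mp4 bytes would do.
with open("sample.mp4", "rb") as f:
    data_url = "data:video/mp4;base64," + base64.b64encode(f.read()).decode()

# The data URL is decoded by load_base64_video, written to a temporary .mp4 file,
# and then read frame by frame via _load_video_by_cv2.
frames = load_video(data_url, num_frames=8)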


def load_audio(