2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
@@ -565,7 +565,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"merge_chunked_attention_for_mla("
"Tensor merged_attn"
"Tensor(a!) merged_attn"
", Tensor temp_attn"
", Tensor merged_softmax_stats"
", Tensor temp_softmax_stats"
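
The only change in this schema is the alias annotation: Tensor(a!) tells the PyTorch dispatcher and torch.compile's functionalization pass that merge_chunked_attention_for_mla writes into merged_attn in place rather than treating it as a read-only input. A minimal Python-side sketch of the same annotation, using a made-up demo::inplace_accumulate op rather than the real kernel:

import torch

# Illustrative only (not part of this PR): the (a!) annotation declares that the
# tagged argument is mutated in place, so compile passes will not assume it is
# read-only.
torch.library.define(
    "demo::inplace_accumulate",
    "(Tensor(a!) out, Tensor src) -> ()",
)

@torch.library.impl("demo::inplace_accumulate", "CompositeExplicitAutograd")
def _inplace_accumulate(out: torch.Tensor, src: torch.Tensor) -> None:
    out.add_(src)  # writes into out, matching the (a!) annotation

out = torch.zeros(4)
torch.ops.demo.inplace_accumulate(out, torch.ones(4))
assert out.sum().item() == 4.0
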
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -135,6 +135,9 @@ class AttentionMetadata:
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
_num_tokens: int = field(init=False, default=0, repr=False)

# The number of tokens in the padded sequence.
padded_num_tokens: Optional[int] = None

# This buffer is currently only used for TrtllmAttentionMetadata.
cache_indirection: Optional[torch.Tensor] = None

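padded_num_tokens complements the existing num_tokens: when the runner pads a batch (for example up to a piecewise CUDA-graph capture size), it records the padded length here so that downstream modules can trim activations back to the real token count. A condensed sketch of the pattern the later files in this diff follow (names from this PR, logic simplified):

import torch

def trim_padding(hidden_states: torch.Tensor, attn_metadata) -> torch.Tensor:
    # No padding in use: pass the tensor through untouched.
    if attn_metadata.padded_num_tokens is None:
        return hidden_states
    # Padded run: the leading dim was grown to the padded size, so keep only
    # the rows that correspond to real tokens.
    assert hidden_states.shape[0] == attn_metadata.padded_num_tokens
    return hidden_states[:attn_metadata.num_tokens]
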
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/compilation/backend.py
@@ -48,7 +48,7 @@ def __init__(
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.capture_num_tokens = capture_num_tokens or []
self.capture_num_tokens = sorted(capture_num_tokens or [])
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False
# We only need to create aux streams.
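Sorting capture_num_tokens is a one-line change; the likely motivation (an assumption, since the lookup itself is not part of this diff) is that an ascending list makes it easy to pick the smallest captured size that can hold a runtime batch, for example with bisect:

import bisect

def pick_capture_size(capture_num_tokens, runtime_num_tokens):
    # capture_num_tokens is assumed sorted ascending, as enforced above.
    idx = bisect.bisect_left(capture_num_tokens, runtime_num_tokens)
    if idx == len(capture_num_tokens):
        return None  # larger than every captured size; fall back to eager
    return capture_num_tokens[idx]  # smallest captured size >= runtime tokens

With capture sizes [64, 128, 256] and a 100-token batch this returns 128; padding the batch from 100 to 128 tokens is plausibly where padded_num_tokens in the previous file comes from.
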
11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -11,8 +11,9 @@

from tensorrt_llm.llmapi.utils import enable_llm_debug

from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
make_weak_ref)
from ..utils import (get_model_extra_attrs,
get_per_request_piecewise_cuda_graph_flag,
get_piecewise_cuda_graph_flag, make_weak_ref)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function

@@ -154,8 +155,10 @@ def __call__(self, *args):
elif isinstance(self.compile_time_num_tokens, int):
runtime_num_of_token = self.compile_time_num_tokens

if runtime_num_of_token is None or runtime_num_of_token not in self.entries or not get_piecewise_cuda_graph_flag(
):
if (runtime_num_of_token is None
or runtime_num_of_token not in self.entries
or not get_piecewise_cuda_graph_flag()
or not get_per_request_piecewise_cuda_graph_flag()):
return self.default_callable(*args)

entry = self.entries[runtime_num_of_token]
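The extra gate means a captured piecewise graph is replayed only when both the global flag and the new per-request flag are enabled and the token count matches a captured entry; otherwise the call falls through to the default callable. Restated outside of __call__ (same logic as the diff above; the import path mirrors the relative import at the top of the file):

from tensorrt_llm._torch.utils import (get_per_request_piecewise_cuda_graph_flag,
                                       get_piecewise_cuda_graph_flag)

def can_replay_captured_graph(runtime_num_of_token, entries) -> bool:
    # entries maps captured token counts to their piecewise CUDA-graph runners.
    return (runtime_num_of_token is not None
            and runtime_num_of_token in entries
            and get_piecewise_cuda_graph_flag()
            and get_per_request_piecewise_cuda_graph_flag())
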
18 changes: 16 additions & 2 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -425,6 +425,9 @@ def forward(
**kwargs,
)

if attn_metadata.padded_num_tokens is not None:
hidden_states = hidden_states[:attn_metadata.num_tokens]

if self.draft_model is not None:
# get logits
logits = self.logits_processor.forward(
@@ -433,9 +436,20 @@
attn_metadata,
True,
)
mtp_input_ids = input_ids
mtp_position_ids = position_ids
if attn_metadata.padded_num_tokens is not None:
if input_ids is not None:
# Slice along the first dimension
mtp_input_ids = input_ids[:attn_metadata.num_tokens]
if position_ids is not None:
# Slice along the last dimension
mtp_position_ids = position_ids[:, :attn_metadata.
num_tokens]

# get accepted tokens and next draft tokens
return self.spec_worker(input_ids=input_ids,
position_ids=position_ids,
return self.spec_worker(input_ids=mtp_input_ids,
position_ids=mtp_position_ids,
hidden_states=hidden_states,
logits=logits,
attn_metadata=attn_metadata,
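The two slice patterns differ because of layout: input_ids is indexed along its only token dimension, while position_ids carries a leading batch dimension, so its token axis is last. A toy illustration of the shapes this implies (shapes inferred from the slicing above, not stated in the diff):

import torch

padded_num_tokens, num_tokens = 8, 5
input_ids = torch.arange(padded_num_tokens)           # shape [8]
position_ids = torch.arange(padded_num_tokens)[None]  # shape [1, 8]

mtp_input_ids = input_ids[:num_tokens]           # slice along dim 0  -> [5]
mtp_position_ids = position_ids[:, :num_tokens]  # slice along last dim -> [1, 5]
assert mtp_input_ids.shape == (num_tokens,)
assert mtp_position_ids.shape == (1, num_tokens)
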
131 changes: 58 additions & 73 deletions tensorrt_llm/_torch/modules/attention.py
@@ -338,6 +338,19 @@ def _attn_impl(
attention_sinks: Optional[torch.Tensor] = None,
):

padded_num_tokens = attn_metadata.padded_num_tokens
num_tokens = attn_metadata.num_tokens

if padded_num_tokens is not None:
assert q.shape[0] == padded_num_tokens
q = q[:num_tokens, :]
if k is not None:
assert k.shape[0] == padded_num_tokens
k = k[:num_tokens, :]
if v is not None:
assert v.shape[0] == padded_num_tokens
v = v[:num_tokens, :]

out_scale = None
out_scale_sf = None
has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -368,7 +381,7 @@
attention_window_size=attention_window_size,
attention_mask_data=attention_mask_data,
enable_attn_nvfp4_output=enable_attn_nvfp4_output,
output=output,
output=output[:num_tokens, :] if output is not None else None,
output_sf=output_sf,
attention_sinks=attention_sinks)
if isinstance(attn_output, tuple):
@@ -936,11 +949,10 @@ def create_output(self, hidden_states: torch.Tensor):
return hidden_states.new_empty([num_tokens, hidden_size],
dtype=hidden_states.dtype)

def forward_impl(self,
position_ids: Optional[torch.Tensor],
def forward_impl(self, position_ids: Optional[torch.Tensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
output: Optional[torch.Tensor] = None) -> torch.Tensor:
output: torch.Tensor) -> None:
"""
Forward pass for the MLA module.

@@ -953,6 +965,18 @@ def forward_impl(self,
Returns:
torch.Tensor: The output tensor.
"""
# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
num_generations = attn_metadata.num_generations
num_ctx_tokens = attn_metadata.num_ctx_tokens
num_tokens = attn_metadata.num_tokens
padded_num_tokens = attn_metadata.padded_num_tokens

if padded_num_tokens is not None:
hidden_states = hidden_states[:num_tokens, ...]
if position_ids is not None:
position_ids = position_ids[:num_tokens, ...]

if self.is_lite:
compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
[self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -980,15 +1004,11 @@
self.aux_stream,
)

# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
num_generations = attn_metadata.num_generations
num_ctx_tokens = attn_metadata.num_ctx_tokens
num_tokens = attn_metadata.num_tokens

assert q.shape[
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

assert output is not None, "output must be provided"

if num_contexts > 0:
q_ctx = q[:num_ctx_tokens, ...]
compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -998,17 +1018,14 @@
assert position_ids is not None
k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)

attn_output_context = self.forward_context(
self.forward_context(
q_ctx,
compressed_kv_ctx,
k_pe_ctx,
attn_metadata,
output[:num_ctx_tokens, :],
latent_cache_ctx,
output=output if num_generations == 0 else None)
if num_generations == 0:
return attn_output_context
else:
attn_output_context = None
)

if num_generations > 0:
q_gen = q[num_ctx_tokens:, ...]
@@ -1019,48 +1036,24 @@
assert position_ids is not None
k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)

attn_output_gen = self.forward_generation(
self.forward_generation(
q_gen,
compressed_kv_gen,
k_pe_gen,
attn_metadata,
output[num_ctx_tokens:num_tokens, :],
latent_cache_gen,
output=output if num_contexts == 0 else None)
if num_contexts == 0:
return attn_output_gen
else:
attn_output_gen = None

# release pytorch activation memory
q = None
compressed_kv = None
k_pe = None

assert attn_output_context is not None and attn_output_gen is not None
assert (
len(attn_output_context.shape) == 2
), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
assert (
len(attn_output_gen.shape) == 2
), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
output = output if output is not None else torch.empty(
(num_tokens, attn_output_context.shape[1]),
dtype=attn_output_context.dtype,
device=attn_output_context.device)
output[:attn_output_context.shape[0], :] = attn_output_context
output[attn_output_context.shape[0]:, :] = attn_output_gen
attn_output_context = None
attn_output_gen = None
return output
)

def forward_context_default(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
latent_cache: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None) -> torch.Tensor:
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
kv = self.kv_b_proj(compressed_kv)
k_nope, v = kv.split(
[
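
With this refactor forward_impl neither allocates nor returns a tensor: the caller passes in the buffer from create_output, the context half is written to output[:num_ctx_tokens, :], the generation half to output[num_ctx_tokens:num_tokens, :], and the old concatenate-and-copy epilogue is gone. forward_context and forward_context_default now take output as a required argument placed ahead of latent_cache. A condensed sketch of the new calling convention (simplified stand-ins for the real context and generation paths):

import torch

def run_mla(num_ctx_tokens, num_tokens, hidden_size, forward_context, forward_generation):
    # Caller-owned buffer (see create_output); only rows for real tokens are written.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output = torch.zeros(num_tokens, hidden_size, device=device)
    if num_ctx_tokens > 0:
        forward_context(output[:num_ctx_tokens, :])               # context rows, in place
    if num_tokens > num_ctx_tokens:
        forward_generation(output[num_ctx_tokens:num_tokens, :])  # generation rows, in place
    return output  # no intermediate attn_output_context / attn_output_gen tensors
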
@@ -1099,7 +1092,7 @@ def forward_context_with_cached_kv(
q: torch.Tensor,
latent_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: Optional[torch.Tensor] = None,
output: torch.Tensor,
) -> torch.Tensor:
assert latent_cache is not None
trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1168,7 +1161,7 @@ def forward_context_with_chunked_prefill(
latent_cache: torch.
Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
attn_metadata: TrtllmAttentionMetadata,
output: Optional[torch.Tensor] = None,
output: torch.Tensor,
) -> torch.Tensor:
trtllm_attention = cast(TrtllmAttention, self.mha)
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
Expand All @@ -1190,11 +1183,8 @@ def forward_context_with_chunked_prefill(
dtype=torch.float,
device='cuda',
)
if output is None:
attn_output = q.new_empty(
(q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
else:
attn_output = output

attn_output = output
temp_attn_output = q.new_empty(
(q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
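
In the chunked-prefill path the final result now lands directly in the caller-provided output, while temp_attn_output holds each chunk's partial result. Presumably (the call itself sits in the collapsed part of this hunk) the partial results are folded in with the merge_chunked_attention_for_mla op from the first file, which is why its merged_attn argument is now marked as mutated. A plain-PyTorch stand-in for that kind of in-place merge, using log-sum-exp softmax statistics as a toy model of the real kernel's stats:

import torch

def merge_chunk_(merged_attn, merged_lse, temp_attn, temp_lse):
    # merged_attn / temp_attn: [tokens, heads, v_dim]; *_lse: [tokens, heads].
    new_lse = torch.logaddexp(merged_lse, temp_lse)
    w_merged = torch.exp(merged_lse - new_lse).unsqueeze(-1)
    w_temp = torch.exp(temp_lse - new_lse).unsqueeze(-1)
    # Rescale the running result and fold in the new chunk, writing in place.
    merged_attn.mul_(w_merged).add_(temp_attn * w_temp)
    merged_lse.copy_(new_lse)
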

@@ -1332,8 +1322,8 @@ def forward_context(
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(self.mha, TrtllmAttention):
assert isinstance(attn_metadata, TrtllmAttentionMetadata)
Expand All @@ -1346,16 +1336,17 @@ def forward_context(
return self.forward_context_with_cached_kv(
q, latent_cache, attn_metadata, output)
return self.forward_context_default(q, compressed_kv, k_pe,
attn_metadata, latent_cache, output)
attn_metadata, output, latent_cache)

def forward_generation(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
latent_cache: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None) -> torch.Tensor:
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_tokens = q.shape[0]
q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1427,12 +1418,6 @@ def forward_generation(
attn_out_latent = attn_out_latent.view(
[-1, self.num_heads, self.kv_lora_rank])

# [seq, num_heads * v_head_dim]
output = output if output is not None else torch.empty(
[num_tokens, self.num_heads * self.v_head_dim],
dtype=attn_out_latent.dtype,
device=attn_out_latent.device)

attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])

if self.v_b_proj.dtype == torch.bfloat16: