@@ -554,6 +554,7 @@ def fused_mha_with_cache(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
freqs_cis: Optional[torch.Tensor],
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""Fused MHA with cache that takes raw input from q, k, v GEMMs."""
# b, s info
@@ -593,6 +594,7 @@ def fused_mha_fake(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
freqs_cis: torch.Tensor,
logit_cap: Optional[float] = None,
):
return torch.empty_like(q.contiguous())

20 changes: 20 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
@@ -37,6 +37,7 @@ class PlanParams:
q_dtype: torch.dtype
kv_dtype: torch.dtype
sm_scale: Optional[float] = None
logit_cap: Optional[float] = None

causal: bool = True

@@ -107,6 +108,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
logits_soft_cap=plan_params.logit_cap,
)

# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
@@ -143,6 +145,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
logits_soft_cap=plan_params.logit_cap,
)
self.plan_params = plan_params

@@ -250,6 +253,7 @@ def flashinfer_mha_with_cache(
scale: Optional[float],
k_scale: float,
v_scale: float,
logit_cap: Optional[float],
) -> torch.Tensor:
# reshape to standard [b*s, n_heads, head_dim] layout
head_dim = k_cache.shape[-1]
@@ -273,6 +277,7 @@ def flashinfer_mha_with_cache(
q_dtype=q.dtype,
kv_dtype=k_cache.dtype,
sm_scale=scale,
logit_cap=logit_cap,
)

# Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache
@@ -327,6 +332,7 @@ def flashinfer_mha_with_cache_fake(
scale: Optional[float],
k_scale: float,
v_scale: float,
logit_cap: Optional[float],
) -> torch.Tensor:
return torch.empty_like(q.contiguous())

@@ -419,8 +425,22 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
ad_logger.warning("Provided scale is not a float. Using default scale instead.")
scale = None

# Get logit_cap from args or kwargs - it's typically the 8th argument (index 7)
if len(source_attn_node.args) > 7:
logit_cap = source_attn_node.args[7]
else:
logit_cap = source_attn_node.kwargs.get("logit_cap", None)

if not (isinstance(logit_cap, float) or logit_cap is None):
ad_logger.debug("Provided logit_cap is not a float or None. Disabling soft-capping.")
logit_cap = None
elif isinstance(logit_cap, float) and logit_cap <= 0:
ad_logger.warning("Provided logit_cap is not positive. Disabling soft-capping.")
logit_cap = None

return [
scale, # softmax scale
1.0, # k_scale
1.0, # v_scale
logit_cap,
]
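
For intuition, the soft-capping requested here via `logits_soft_cap` follows the Gemma-2 formulation, `logit_cap * tanh(scores / logit_cap)`: small scores pass through nearly unchanged while large ones are bounded to (-logit_cap, logit_cap). A minimal standalone sketch of that transform (not the FlashInfer internals):

```python
import torch


def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    """Gemma-2-style tanh soft-capping: output is bounded to (-logit_cap, logit_cap)."""
    return logit_cap * torch.tanh(scores / logit_cap)


scores = torch.tensor([0.5, 5.0, 50.0, 500.0])
print(soft_cap(scores, logit_cap=50.0))  # ~[0.50, 4.98, 38.1, 50.0]: small values intact, large ones saturate
```
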
61 changes: 48 additions & 13 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
@@ -40,27 +40,58 @@ def scaled_dot_product_attention(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.

Using this custom op instead of using the functional directly ensures consistent representation
of the vanilla sdpa in a graph.
"""

return F.scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
    # F.scaled_dot_product_attention does not expose a soft-capping argument, so when
    # logit_cap is set, compute attention manually and apply tanh soft-capping to the scores.
    if logit_cap is not None:
        d_k = query.size(-1)
        if scale is None:
            scale = 1.0 / (d_k**0.5)

        # Compute scaled attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale

        # Apply soft capping: logit_cap * tanh(scores / logit_cap)
        scores = torch.tanh(scores / logit_cap) * logit_cap

        # Mirror F.scaled_dot_product_attention's masking semantics
        if is_causal:
            causal_mask = torch.ones(
                query.shape[-2], key.shape[-2], dtype=torch.bool, device=query.device
            ).tril()
            scores = scores.masked_fill(~causal_mask, float("-inf"))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask

        # Apply softmax
        attn_weights = F.softmax(scores, dim=-1)

        # Apply dropout if specified
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p, training=torch.is_grad_enabled())

        # Apply attention to values
        output = torch.matmul(attn_weights, value)
        return output.contiguous()
else:
# Use standard SDPA when no soft capping
return F.scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)


@scaled_dot_product_attention.register_fake
def scaled_dot_product_attention_fake(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None
Review comment on lines 93 to 94 (Collaborator): @nvchenghaoz, how about we wait to merge this PR until we have decided how we will support more attention features/arguments?

):
"""Fake implementation of scaled_dot_product_attention."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@@ -75,10 +106,12 @@ def grouped_sdpa(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""SDPA attention that can handle GQA."""

return F.scaled_dot_product_attention(
# Use our custom scaled_dot_product_attention that supports soft capping
return scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
@@ -99,6 +132,7 @@ def grouped_sdpa_fake(
dropout_p=0.0,
is_causal=False,
scale=None,
logit_cap=None,
):
"""Fake implementation of grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@@ -113,6 +147,7 @@ def bsnd_grouped_sdpa(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""Attention that assumes the input layout is bsnd.

@@ -124,15 +159,15 @@ def bsnd_grouped_sdpa(
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()

out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale, logit_cap)

# let's transpose back to bnsd
return out.transpose(1, 2).contiguous()


@bsnd_grouped_sdpa.register_fake
def bsnd_grouped_sdpa_fake(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None
):
"""Fake implementation of bnsd grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
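
As a rough sanity check on the soft-capping path added above: with a cap far above the score magnitudes, the tanh is effectively the identity and the result should match plain SDPA. A standalone sketch of that comparison (illustrative only, not part of this PR's tests):

```python
import torch
import torch.nn.functional as F


def sdpa_with_soft_cap(q, k, v, logit_cap, scale=None):
    # Mirrors the manual branch: scaled scores -> tanh cap -> softmax -> weighted sum.
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    scores = logit_cap * torch.tanh(scores / logit_cap)
    return torch.softmax(scores, dim=-1) @ v


q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))  # [batch, n_heads, seq, head_dim]
ref = F.scaled_dot_product_attention(q, k, v)
capped = sdpa_with_soft_cap(q, k, v, logit_cap=1e6)  # cap >> score range, so capping is a near no-op
print(torch.allclose(ref, capped, atol=1e-5))  # should print True (up to float tolerance)
```
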
17 changes: 15 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
@@ -40,6 +40,7 @@ def _generate_mha(
cache_locs: torch.Tensor,
input_pos: torch.Tensor,
scale: float,
logit_cap: Optional[float],
out: torch.Tensor,
):
b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
@@ -97,6 +98,7 @@ def _generate_mha(
v_d_head,
SEQ_BLOCK_SIZE,
HEAD_BLOCK_SIZE,
LOGIT_CAP=logit_cap,
)
attention_kv_stage2[(b, n_heads, 1)](
stage1_output_values,
@@ -122,6 +124,7 @@ def _flattened_context_mha(
seq_start: torch.Tensor,
scale: float,
out: torch.Tensor,
logit_cap: Optional[float],
) -> None:
# NOTE: s_total == sum(seq_len)
s_total, n_heads, q_d_head = q.shape
@@ -166,6 +169,7 @@ def _flattened_context_mha(
SEQ_BLOCK,
max_cache_seq_len,
num_stages=2,
LOGIT_CAP=logit_cap,
)


@@ -187,6 +191,7 @@ def flattened_mha_with_cache(
# <none>
# CONSTANTS
scale: Optional[float],
logit_cap: Optional[float],
) -> torch.Tensor:
"""Flattened MHA with cache that takes q, k, v in BSND layout.

@@ -223,7 +228,7 @@ def flattened_mha_with_cache(
y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
if s == 1:
# generate-only phase
_generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y)
_generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, logit_cap, y)
else:
# mixed context + generate phase
_flattened_context_mha(
@@ -237,7 +242,8 @@ def flattened_mha_with_cache(
seq_len,
seq_start,
scale,
y,
out=y,
logit_cap=logit_cap,
)

return y.view(*output_shape)
@@ -255,6 +261,7 @@ def flattened_mha_fake(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
scale: Optional[float],
logit_cap: Optional[float],
):
return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()

@@ -389,6 +396,12 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
ad_logger.warning("Provided scale is not a float, Using default scale instead.")
scale = None

if len(source_attn_node.args) > 7:
logit_cap = source_attn_node.args[7]
else:
logit_cap = source_attn_node.kwargs.get("logit_cap", None)

return [
scale, # softmax scale
logit_cap, # soft capping scale
]
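
Note that `LOGIT_CAP` is threaded through as a `tl.constexpr` with a `None` default, so the `if LOGIT_CAP is not None:` checks in the kernels (next file) are resolved at compile time and the uncapped path compiles to the same code as before. A toy sketch of that specialization pattern (hypothetical kernel, not from this PR):

```python
import torch
import triton
import triton.language as tl
from triton.language.extra.libdevice import tanh


@triton.jit
def scale_and_cap_kernel(
    x_ptr, y_ptr, n_elements, SCALE: tl.constexpr, BLOCK: tl.constexpr, LOGIT_CAP: tl.constexpr = None
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask) * SCALE
    if LOGIT_CAP is not None:  # specialized away at compile time when LOGIT_CAP is None
        x = LOGIT_CAP * tanh(x / LOGIT_CAP)
    tl.store(y_ptr + offs, x, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
scale_and_cap_kernel[grid](x, y, x.numel(), SCALE=0.125, BLOCK=1024, LOGIT_CAP=30.0)
```
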
@@ -2,6 +2,7 @@

import triton
from triton import language as tl
from triton.language.extra.libdevice import tanh


@triton.jit
@@ -112,6 +113,7 @@ def gqa_attention_kv_stage1(
V_D_HEAD: tl.constexpr, # Dimension of each key/value head
SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
LOGIT_CAP: tl.constexpr = None, # softcapping introduced in the Gemma 2 paper
):
"""Attention kernel to be used for generate-only batches.

@@ -126,8 +128,10 @@
1. Fetch the K-cache from 0 to input_pos
2. Fetch the V-cache from 0 to input_pos
3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
4. S = softmax(A)
5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
4. A = A * scale
5. A = logit_cap * tanh(A / logit_cap) if logit_cap is not None
6. S = softmax(A)
7. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
"""
# Assume KV-cache layout: [Batch, Seq, Head, Dim]
# A program is responsible for 1 batch, 1 head and a block of sequences.
@@ -200,6 +204,8 @@
attn = tl.dot(q, k.trans()) # [N, seq_block]
attn = attn.to(tl.float32)
attn *= SCALE
if LOGIT_CAP is not None:
attn = LOGIT_CAP * tanh(attn / LOGIT_CAP)
# Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf"))
# compute max_attn only when invalid attn values are masked out.
@@ -573,6 +579,7 @@ def context_attention_kv_flattened(
V_D_HEAD: tl.constexpr, # Dimension of each value head.
SEQ_BLOCK: tl.constexpr,
MAX_SEQ_LENGTH: tl.constexpr,
LOGIT_CAP: tl.constexpr = None,
):
"""Kernel for context phase.

@@ -641,6 +648,8 @@
(seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
)
qk *= SCALE
if LOGIT_CAP is not None:
qk = LOGIT_CAP * tanh(qk / LOGIT_CAP)
# rowmax
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
p = tl.exp(qk - m_ij[:, None])
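
The decode path splits attention across sequence blocks: `gqa_attention_kv_stage1` writes per-block partial softmax statistics and `attention_kv_stage2` reduces them. Numerically, the result (with soft capping) should agree with a simple single-query reference like the sketch below (assumes one KV head per query head for brevity; the real kernels handle GQA and flattened caches):

```python
import torch


def decode_attention_reference(q, k_cache, v_cache, input_pos, scale, logit_cap=None):
    """Generate-phase attention of one query token over the first `input_pos` cached entries.

    q:       [n_heads, head_dim]
    k_cache: [max_seq, n_heads, head_dim]
    v_cache: [max_seq, n_heads, head_dim]
    """
    k = k_cache[:input_pos].transpose(0, 1)  # [n_heads, seq, head_dim]
    v = v_cache[:input_pos].transpose(0, 1)  # [n_heads, seq, head_dim]
    scores = torch.einsum("hd,hsd->hs", q, k) * scale
    if logit_cap is not None:
        scores = logit_cap * torch.tanh(scores / logit_cap)  # same capping as the kernels
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,hsd->hd", weights, v)  # [n_heads, head_dim]
```
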
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/decilm.py
@@ -7,7 +7,11 @@

def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs):
print(str(pretrained_model_name_or_path))
if re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)):

# Use the eager attention implementation for Gemma-2 models so that the soft logit capping ops show up in the imported graph.
if re.search(
r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)
) or re.search(r"gemma-2", str(pretrained_model_name_or_path), re.IGNORECASE):
kwargs["attn_implementation"] = "eager"
return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs)
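
For reference, the patched check forces `attn_implementation="eager"` roughly for names like these (illustrative names, not an exhaustive list):

```python
import re

names = [
    "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1",  # matched by the Nemotron pattern
    "google/gemma-2-9b-it",                     # matched case-insensitively by "gemma-2"
    "meta-llama/Llama-3.1-8B-Instruct",         # no match: keeps the default implementation
]
for name in names:
    force_eager = bool(
        re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", name)
        or re.search(r"gemma-2", name, re.IGNORECASE)
    )
    print(name, "->", "eager" if force_eager else "default")
```
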
