From d3c87aa5a36cf54f0cc3012a3360b68ceb2739ed Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Wed, 11 Jun 2025 23:18:49 +0000 Subject: [PATCH 1/6] Add the softcap to the triton kernel Signed-off-by: Chenghao Zhang --- .../auto_deploy/custom_ops/_triton_attention_internal.py | 2 ++ .../custom_ops/triton_kernels/attention_with_kv_cache.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py index 18452d3b417..9b018f5a19f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py @@ -554,6 +554,7 @@ def fused_mha_with_cache( k_cache: torch.Tensor, v_cache: torch.Tensor, freqs_cis: Optional[torch.Tensor], + logit_cap: Optional[float] = None, ) -> torch.Tensor: """Fused MHA with cache that takes raw input from q, k, v GEMMs.""" # b, s info @@ -593,6 +594,7 @@ def fused_mha_fake( k_cache: torch.Tensor, v_cache: torch.Tensor, freqs_cis: torch.Tensor, + logit_cap: Optional[float] = None, ): return torch.empty_like(q.contiguous()) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py index 9a59a363dc4..af9ec93a672 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py @@ -2,6 +2,7 @@ import triton from triton import language as tl +from triton.language.extra.libdevice import tanh @triton.jit @@ -112,6 +113,7 @@ def gqa_attention_kv_stage1( V_D_HEAD: tl.constexpr, # Dimension of each key/value head SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim. HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores. + LOGIT_CAP: tl.constexpr = None, # softcapping introduced in the Gemma 2 paper ): """Attention kernel to be used for generate-only batches. @@ -200,6 +202,8 @@ def gqa_attention_kv_stage1( attn = tl.dot(q, k.trans()) # [N, seq_block] attn = attn.to(tl.float32) attn *= SCALE + if LOGIT_CAP is not None: + attn = LOGIT_CAP * tanh(attn / LOGIT_CAP) # Set to -inf attn values where mask is not set. This forces exp(attn) to 0. attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf")) # compute max_attn only when invalid attn values are masked out. From 45d8e914f3041064f11b8a515446e959cb49e0d9 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 12 Jun 2025 00:10:28 +0000 Subject: [PATCH 2/6] Add the softcap to flashinfer attention and add tests. 
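
For reference, the soft-capping added here (and exercised by the new tests) is the
Gemma 2 style tanh cap applied to the raw attention scores after scaling and before
the mask and softmax. A minimal PyTorch sketch of the formula only; the helper name
below is illustrative and not part of this change:

    from typing import Optional

    import torch

    def soft_cap(scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
        """Apply tanh soft-capping to attention scores; no-op when logit_cap is None."""
        if logit_cap is None:
            return scores
        return logit_cap * torch.tanh(scores / logit_cap)

Passing logit_cap=None keeps the previous behavior, so existing models are unaffected.
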
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../custom_ops/flashinfer_attention.py | 15 +++ .../singlegpu/custom_ops/test_attention_op.py | 84 +++++++++++++ .../test_flashinfer_attention_op.py | 119 ++++++++++++++++++ 3 files changed, 218 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 6682299a656..b4d25dfc1d9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -37,6 +37,7 @@ class PlanParams: q_dtype: torch.dtype kv_dtype: torch.dtype sm_scale: Optional[float] = None + logit_cap: Optional[float] = None causal: bool = True @@ -107,6 +108,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper): q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + logits_soft_cap=plan_params.logit_cap, ) # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached @@ -143,6 +145,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper): q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, + logits_soft_cap=plan_params.logit_cap, ) self.plan_params = plan_params @@ -250,6 +253,7 @@ def flashinfer_mha_with_cache( scale: Optional[float], k_scale: float, v_scale: float, + logit_cap: Optional[float], ) -> torch.Tensor: # reshape to standard [b*s, n_heads, head_dim] layout head_dim = k_cache.shape[-1] @@ -273,6 +277,7 @@ def flashinfer_mha_with_cache( q_dtype=q.dtype, kv_dtype=k_cache.dtype, sm_scale=scale, + logit_cap=logit_cap, ) # Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache @@ -327,6 +332,7 @@ def flashinfer_mha_with_cache_fake( scale: Optional[float], k_scale: float, v_scale: float, + logit_cap: Optional[float], ) -> torch.Tensor: return torch.empty_like(q.contiguous()) @@ -419,8 +425,17 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: ad_logger.warning("Provided scale is not a float. Using default scale instead.") scale = None + logit_cap = source_attn_node.kwargs.get("logit_cap", None) + if not (isinstance(logit_cap, float) or logit_cap is None): + ad_logger.debug("Provided logit_cap is not a float or None. Disabling soft-capping.") + logit_cap = None + elif isinstance(logit_cap, float) and logit_cap <= 0: + ad_logger.warning("Provided logit_cap is not positive. 
Disabling soft-capping.") + logit_cap = None + return [ scale, # softmax scale 1.0, # k_scale 1.0, # v_scale + logit_cap, ] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py index cfc5ac1891c..a763852bd10 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py @@ -100,6 +100,90 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len): ) +@pytest.mark.parametrize("logit_cap", [50.0]) +@pytest.mark.parametrize("group_size", [1, 4]) +@pytest.mark.parametrize("n_heads", [8]) +@pytest.mark.parametrize("dtype", ["float16", "float32"]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_gqa_op_with_logit_cap(device, dtype, n_heads, group_size, logit_cap): + # This test is for generation phase, so seq_len is 1. + seq_len = 1 + BATCH_SIZE = 2 + D_HEAD = 16 + CACHE_SEQ_LEN = 8 + + dtype = getattr(torch, dtype) + n_kv_heads = n_heads // group_size + + offset = 4 # some offset + input_positions = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int) + offset + + q = torch.randn(BATCH_SIZE, seq_len, n_heads, D_HEAD, dtype=dtype, device=device) + k = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device) + v = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device) + + # setup kv-cache + k_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device) + v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device) + + # Store k,v in cache for op + k_cache_op = k_cache.clone() + v_cache_op = v_cache.clone() + + # run custom op + output = torch.ops.attention.fused_mha_with_cache( + q, k, v, input_positions, k_cache_op, v_cache_op, None, logit_cap + ) + + # for reference, we manually update the cache + k_cache[:, input_positions[0] : input_positions[0] + seq_len] = k + v_cache[:, input_positions[0] : input_positions[0] + seq_len] = v + + k_cache_ref = torch.repeat_interleave(k_cache, group_size, dim=2) # [b,s,n,d] + v_cache_ref = torch.repeat_interleave(v_cache, group_size, dim=2) # [b,s,n,d] + + # Reference implementation + q_ref = q.transpose(1, 2) + # up to `offset + 1` + k_ref = k_cache_ref[:, : offset + seq_len].transpose(1, 2) + v_ref = v_cache_ref[:, : offset + seq_len].transpose(1, 2) + + scale = 1.0 / (D_HEAD**0.5) + attn = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale + + if logit_cap is not None: + attn = logit_cap * torch.tanh(attn / logit_cap) + + # For seq_len=1, there is no causal mask. We attend to all keys in cache up to current position. 
+ + attn = torch.nn.functional.softmax(attn, dim=-1) + ref_out = torch.matmul(attn, v_ref) + + ref = ref_out.transpose(1, 2).contiguous().view(BATCH_SIZE, seq_len, n_heads * D_HEAD) + + # Check that op output and reference are close + assert torch.allclose( + ref.cpu().to(torch.float32), + output.cpu().to(torch.float32), + atol=1e-2, + rtol=1e-2, + ) + + # Check that cache is updated correctly by the op + assert torch.allclose( + k_cache_op.cpu(), + k_cache.cpu(), + atol=1e-5, + rtol=1e-5, + ) + assert torch.allclose( + v_cache_op.cpu(), + v_cache.cpu(), + atol=1e-5, + rtol=1e-5, + ) + + @pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("max_seq_len", [0, 1, 16]) @pytest.mark.parametrize("group_size", [1, 4]) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index 4872aef2210..cf1b41bc6e2 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -109,6 +109,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, None, 1.0, 1.0, + None, # logit_cap ) ref = torch.nn.functional.scaled_dot_product_attention( @@ -234,6 +235,7 @@ def test_flashinfer_attention_op_decode( None, 1.0, 1.0, + None, # logit_cap ) assert torch.allclose( @@ -350,6 +352,7 @@ def test_flashinfer_attention_context_and_generate( None, 1.0, 1.0, + None, # logit_cap ) # Generate reference outputs @@ -425,6 +428,7 @@ def test_flashinfer_attention_context_and_generate( None, 1.0, 1.0, + None, # logit_cap ) # Generate reference outputs @@ -534,6 +538,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty None, 1.0, 1.0, + None, # logit_cap ) # Generate ref @@ -681,6 +686,7 @@ def test_flashinfer_attention_with_fp8_cache( None, K_SCALE, V_SCALE, + None, # logit_cap ) y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD) @@ -778,6 +784,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de None, 1.0, 1.0, + None, # logit_cap ) # Compute reference @@ -861,6 +868,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de None, 1.0, 1.0, + None, # logit_cap ) # Compute reference @@ -886,3 +894,114 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de atol=1e-2, rtol=1e-2, ) + + +@pytest.mark.parametrize("seq_length", [64]) +@pytest.mark.parametrize("n_heads", [8]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("logit_cap", [10.0, 30.0]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_attention_op_context_with_logit_cap( + seq_length, n_heads, batch_size, logit_cap, dtype, device +): + """ + Tests the context phase of flashinfer attention with logit soft-capping. 
+ """ + D_HEAD = 64 + MAX_SEQ_LEN = 2048 + MAX_BATCH_SIZE = 32 + DTYPE = dtype + BATCH_SIZE = batch_size + N_HEADS = n_heads + SEQ_LEN = seq_length + + # metadata + seq_len_tensor = torch.tensor([SEQ_LEN] * BATCH_SIZE, dtype=torch.int32, device=device) + offsets = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int) + + qo_indptr = torch.cat( + (torch.zeros_like(seq_len_tensor[:1]), torch.cumsum(seq_len_tensor, 0)) + ).to(torch.int32) + paged_kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) + paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) + paged_kv_last_page_len = offsets + seq_len_tensor + + # Q,K,V are computed using GEMM. + q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) + k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) + v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) + + # Setup KV Cache. KV cache is empty, context phase + k_cache = torch.zeros( + (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device + ) + v_cache = torch.zeros( + (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device + ) + + # make sure planner is initialized + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + _GlobalFlashInferPlanner.init_workspace(workspace) + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1] + ), + BATCH_SIZE * SEQ_LEN, + ) + flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache( + # Q, K, V + q, + k, + v, + # METADATA + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + batch_indices, + positions, + # CACHES + k_cache, + v_cache, + # BUFFERS + workspace, + # CONSTANTS + None, + 1.0, + 1.0, + logit_cap, + ) + + # Reference implementation + q_ref = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2) + k_ref = k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2) + v_ref = v.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2) + + scale = D_HEAD**-0.5 + logits = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale + + # Apply logit softcapping + if logit_cap > 0.0: + logits = logit_cap * torch.tanh(logits / logit_cap) + + # Apply causal mask + causal_mask = torch.triu( + torch.ones(SEQ_LEN, SEQ_LEN, device=device, dtype=torch.bool), diagonal=1 + ) + logits.masked_fill_(causal_mask, -float("inf")) + + # Apply softmax + attn_weights = torch.softmax(logits, dim=-1).to(v_ref.dtype) + + # Compute output + ref = (attn_weights @ v_ref).transpose(1, 2).reshape(BATCH_SIZE, SEQ_LEN, -1) + + assert torch.allclose( + flashinfer_output.cpu().to(torch.float32), + ref.cpu().to(torch.float32), + atol=1e-2, + rtol=1e-2, + ) From f1dcc350effb8b551c8d7657690194a59fd3e6d3 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 12 Jun 2025 00:13:16 +0000 Subject: [PATCH 3/6] Minor change - Add log Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../_torch/auto_deploy/transformations/library/kvcache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py index 97a4ef3fdac..c6af8788c88 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ 
b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py @@ -68,6 +68,7 @@ def insert_cached_attention( if not source_attn_nodes: # If there are no nodes for kv cache insertion found, return current graph + ad_logger.info("No source attention nodes found, skipping cache insertion.") return egm # Sanity check From 7228d980ea3ac3c3b82949d4830145d7378b979a Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Wed, 18 Jun 2025 18:13:42 +0000 Subject: [PATCH 4/6] Add the softcap import and related transform Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../auto_deploy/custom_ops/torch_attention.py | 63 ++++++++++++++----- .../custom_ops/triton_attention.py | 27 ++++---- .../triton_kernels/attention_with_kv_cache.py | 9 ++- .../_torch/auto_deploy/models/decilm.py | 6 +- .../transformations/library/attention.py | 51 ++++++++++++--- .../auto_deploy/transformations/transform.py | 5 +- 6 files changed, 122 insertions(+), 39 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index 6764ca3d91e..4d9131fbf11 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -40,6 +40,7 @@ def scaled_dot_product_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + logit_cap: Optional[float] = None, ) -> torch.Tensor: """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op. @@ -47,20 +48,50 @@ def scaled_dot_product_attention( of the vanilla sdpa in a graph. """ - return F.scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - ) + # Handle soft capping by applying it manually since F.scaled_dot_product_attention + # may not support soft_cap parameter + if logit_cap is not None: + # Apply manual soft capping to the attention scores + # First compute raw attention scores + d_k = query.size(-1) + if scale is None: + scale = 1.0 / (d_k**0.5) + + # Compute attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply soft capping: tanh(scores / logit_cap) * logit_cap + scores = torch.tanh(scores / logit_cap) * logit_cap + + if attn_mask is not None: + scores += attn_mask + + # Apply softmax + attn_weights = F.softmax(scores, dim=-1) + + # Apply dropout if specified + if dropout_p > 0.0: + attn_weights = F.dropout(attn_weights, p=dropout_p, training=torch.is_grad_enabled()) + + # Apply attention to values + output = torch.matmul(attn_weights, value) + return output.contiguous() + else: + # Use standard SDPA when no soft capping + return F.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) @scaled_dot_product_attention.register_fake def scaled_dot_product_attention_fake( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None ): """Fake implementation of scaled_dot_product_attention.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() @@ -75,10 +106,12 @@ def grouped_sdpa( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + 
logit_cap: Optional[float] = None, ) -> torch.Tensor: """SDPA attention that can handle GQA.""" - return F.scaled_dot_product_attention( + # Use our custom scaled_dot_product_attention that supports soft capping + return scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -86,7 +119,7 @@ def grouped_sdpa( dropout_p=dropout_p, is_causal=is_causal, scale=scale, - enable_gqa=True, + logit_cap=logit_cap, ) @@ -99,6 +132,7 @@ def grouped_sdpa_fake( dropout_p=0.0, is_causal=False, scale=None, + logit_cap=None, ): """Fake implementation of grouped SDPA.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() @@ -113,6 +147,7 @@ def bsnd_grouped_sdpa( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + logit_cap: Optional[float] = None, ) -> torch.Tensor: """Attention that assumes the input layout is bsnd. @@ -124,7 +159,7 @@ def bsnd_grouped_sdpa( key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale) + out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale, logit_cap) # let's transpose back to bnsd return out.transpose(1, 2).contiguous() @@ -132,7 +167,7 @@ def bsnd_grouped_sdpa( @bsnd_grouped_sdpa.register_fake def bsnd_grouped_sdpa_fake( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None ): """Fake implementation of bnsd grouped SDPA.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index c95e1c28547..2f681a7cf90 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -40,6 +40,7 @@ def _generate_mha( cache_locs: torch.Tensor, input_pos: torch.Tensor, scale: float, + logit_cap: Optional[float], out: torch.Tensor, ): b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:] @@ -55,7 +56,6 @@ def _generate_mha( stage1_output_logsumexp = torch.empty( b, n_heads, num_blocks, device=device, dtype=torch.float32 ) - float("inf") - update_kv_cache[(b, n_kv_heads, 1)]( k, v, @@ -74,13 +74,7 @@ def _generate_mha( ) HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads)) - gqa_attention_kv_stage1[ - ( - b, - n_kv_heads, - num_blocks, - ) - ]( + gqa_attention_kv_stage1[(b, n_heads, num_blocks)]( q, k_cache, v_cache, @@ -97,6 +91,7 @@ def _generate_mha( v_d_head, SEQ_BLOCK_SIZE, HEAD_BLOCK_SIZE, + LOGIT_CAP=logit_cap, ) attention_kv_stage2[(b, n_heads, 1)]( stage1_output_values, @@ -122,6 +117,7 @@ def _flattened_context_mha( seq_start: torch.Tensor, scale: float, out: torch.Tensor, + logit_cap: Optional[float], ) -> None: # NOTE: s_total == sum(seq_len) s_total, n_heads, q_d_head = q.shape @@ -166,6 +162,7 @@ def _flattened_context_mha( SEQ_BLOCK, max_cache_seq_len, num_stages=2, + LOGIT_CAP=logit_cap, ) @@ -187,6 +184,7 @@ def flattened_mha_with_cache( # # CONSTANTS scale: Optional[float], + logit_cap: Optional[float], ) -> torch.Tensor: """Flattened MHA with cache that takes q, k, v in BSND layout. 
@@ -223,7 +221,7 @@ def flattened_mha_with_cache( y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous() if s == 1: # generate-only phase - _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y) + _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, logit_cap, y) else: # mixed context + generate phase _flattened_context_mha( @@ -237,7 +235,8 @@ def flattened_mha_with_cache( seq_len, seq_start, scale, - y, + out=y, + logit_cap=logit_cap, ) return y.view(*output_shape) @@ -255,6 +254,7 @@ def flattened_mha_fake( k_cache: torch.Tensor, v_cache: torch.Tensor, scale: Optional[float], + logit_cap: Optional[float], ): return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous() @@ -382,13 +382,18 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: scale = source_attn_node.args[6] else: scale = source_attn_node.kwargs.get("scale", None) - # do a sanity check on the scale if it is not None, we only support the default scale # of 1/sqrt(head_dim) and so we should do an approximate check for that one if not isinstance(scale, float): ad_logger.warning("Provided scale is not a float, Using default scale instead.") scale = None + if len(source_attn_node.args) > 7: + logit_cap = source_attn_node.args[7] + else: + logit_cap = source_attn_node.kwargs.get("logit_cap", None) + return [ scale, # softmax scale + logit_cap, # soft capping scale ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py index af9ec93a672..ad2e1928768 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py @@ -128,8 +128,10 @@ def gqa_attention_kv_stage1( 1. Fetch the K-cache from 0 to input_pos 2. Fetch the V-cache from 0 to input_pos 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len] - 4. S = softmax(A) - 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD] + 4. A = A * scale + 5. A = A * logit_cap if logit_cap is not None + 6. S = softmax(A) + 7. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD] """ # Assume KV-cache layout: [Batch, Seq, Head, Dim] # A program is responsible for 1 batch, 1 head and a block of sequences. @@ -577,6 +579,7 @@ def context_attention_kv_flattened( V_D_HEAD: tl.constexpr, # Dimension of each value head. SEQ_BLOCK: tl.constexpr, MAX_SEQ_LENGTH: tl.constexpr, + LOGIT_CAP: tl.constexpr = None, ): """Kernel for context phase. @@ -645,6 +648,8 @@ def context_attention_kv_flattened( (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf") ) qk *= SCALE + if LOGIT_CAP is not None: + qk = LOGIT_CAP * tanh(qk / LOGIT_CAP) # rowmax m_ij = tl.maximum(tl.max(qk, 1), lse_i) p = tl.exp(qk - m_ij[:, None]) diff --git a/tensorrt_llm/_torch/auto_deploy/models/decilm.py b/tensorrt_llm/_torch/auto_deploy/models/decilm.py index 1a9f7368a64..d57f2fefe06 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/decilm.py +++ b/tensorrt_llm/_torch/auto_deploy/models/decilm.py @@ -7,7 +7,11 @@ def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs): print(str(pretrained_model_name_or_path)) - if re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)): + + # Use the eager attention implementation for Gemma-2 models to import the soft logit capping ops. 
+ if re.search( + r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path) + ) or re.search(r"gemma-2", str(pretrained_model_name_or_path), re.IGNORECASE): kwargs["attn_implementation"] = "eager" return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py index 7e46bd652ce..c8e47336d2f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py @@ -293,7 +293,8 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str Match the eager attention pattern starting from the final matmul node. The pattern is: - transpose -> matmul -> mul/div -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul + transpose -> matmul -> mul/div -> (optional) div -> tanh -> mul (soft capping) + -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul Returns a dictionary with information about the match or None if no match. """ @@ -352,21 +353,51 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str prev_node = prev_node.args[0] # Check for attention mask pattern (add node) + attn_mask = None if is_op(prev_node, torch.ops.aten.add): add_node = prev_node attn_mask = add_node.args[1] # Second arg is the mask - # The add should have a mul or div node as its first argument + # The add should have input as its first argument if len(add_node.args) < 1: return None - scaling_node = add_node.args[0] - if not (is_op(scaling_node, torch.ops.aten.mul) or is_op(scaling_node, torch.ops.aten.div)): - return None - elif is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div): - # No mask case - the softmax input is directly the mul or div node + prev_node = add_node.args[0] + + # Check for optional soft capping pattern: div -> tanh -> mul + logit_cap = None + if is_op(prev_node, torch.ops.aten.mul): + # Check if this mul is part of soft capping (mul after tanh) + if len(prev_node.args) >= 2: + mul_input = prev_node.args[0] + soft_cap_mul_factor = prev_node.args[1] + + # Check if the input to mul is tanh + if is_op(mul_input, torch.ops.aten.tanh): + if len(mul_input.args) >= 1: + tanh_input = mul_input.args[0] + + # Check if the input to tanh is div (completing the soft cap pattern) + if is_op(tanh_input, torch.ops.aten.div): + if len(tanh_input.args) >= 2: + div_input = tanh_input.args[0] + soft_cap_div_factor = tanh_input.args[1] + + # Verify that the div and mul factors are the same (soft cap scale) + if isinstance(soft_cap_div_factor, (float, int)) and isinstance( + soft_cap_mul_factor, (float, int) + ): + if abs(soft_cap_div_factor - soft_cap_mul_factor) < 1e-6: + logit_cap = soft_cap_div_factor + prev_node = div_input + elif soft_cap_div_factor == soft_cap_mul_factor: + # Same node/tensor used for both operations + logit_cap = soft_cap_div_factor + prev_node = div_input + + # Now prev_node should be the scaling operation (mul or div) + if is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div): scaling_node = prev_node - attn_mask = None else: return None @@ -422,6 +453,10 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str if attn_mask is not None: match_info["attn_mask"] = attn_mask + # Add soft cap scale if it exists + if logit_cap is not None: + 
match_info["logit_cap"] = logit_cap + return match_info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index 921844a1f4e..2c11d759f1c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -182,11 +182,10 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule: from .library import visualize_namespace visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes) + except ImportError: ad_logger.warning( - "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize" - " the graph." + "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize the graph." ) - except ImportError: pass ############################################################################################ From cf9023ae50d60c1f5d40cd2b803f3491c99c0f79 Mon Sep 17 00:00:00 2001 From: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com> Date: Tue, 24 Jun 2025 22:28:16 -0700 Subject: [PATCH 5/6] Extract logit_cap from the args position not the value Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/flashinfer_attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index b4d25dfc1d9..c18e6ed1529 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -425,7 +425,12 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: ad_logger.warning("Provided scale is not a float. Using default scale instead.") scale = None - logit_cap = source_attn_node.kwargs.get("logit_cap", None) + # Get logit_cap from args or kwargs - it's typically the 8th argument (index 7) + if len(source_attn_node.args) > 7: + logit_cap = source_attn_node.args[7] + else: + logit_cap = source_attn_node.kwargs.get("logit_cap", None) + if not (isinstance(logit_cap, float) or logit_cap is None): ad_logger.debug("Provided logit_cap is not a float or None. 
Disabling soft-capping.") logit_cap = None From f83e6744cdf8248e3d0c681d7fdef4786767efe7 Mon Sep 17 00:00:00 2001 From: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com> Date: Wed, 25 Jun 2025 13:30:30 -0700 Subject: [PATCH 6/6] Fix the minor error during the kernel call Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/torch_attention.py | 2 +- .../_torch/auto_deploy/custom_ops/triton_attention.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index 4d9131fbf11..fecd1ffaac2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -119,7 +119,7 @@ def grouped_sdpa( dropout_p=dropout_p, is_causal=is_causal, scale=scale, - logit_cap=logit_cap, + enable_gqa=True, ) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 2f681a7cf90..1a7a3da2069 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -56,6 +56,7 @@ def _generate_mha( stage1_output_logsumexp = torch.empty( b, n_heads, num_blocks, device=device, dtype=torch.float32 ) - float("inf") + update_kv_cache[(b, n_kv_heads, 1)]( k, v, @@ -74,7 +75,13 @@ def _generate_mha( ) HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads)) - gqa_attention_kv_stage1[(b, n_heads, num_blocks)]( + gqa_attention_kv_stage1[ + ( + b, + n_kv_heads, + num_blocks, + ) + ]( q, k_cache, v_cache, @@ -382,6 +389,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: scale = source_attn_node.args[6] else: scale = source_attn_node.kwargs.get("scale", None) + # do a sanity check on the scale if it is not None, we only support the default scale # of 1/sqrt(head_dim) and so we should do an approximate check for that one if not isinstance(scale, float):