
Conversation


@nvchenghaoz commented Jun 24, 2025

Description

This PR enables full support for logit softcapping. The main changes are:

  1. Use eager mode to import the logit-softcapping operators such as tanh and div. Related file: tensorrt_llm/_torch/auto_deploy/models/decilm.py
  2. Extend the pattern matching so it supports the logit softcapping structure as an optional route.
  3. Update the torch attention backend to support logit softcapping.
  4. Minor updates so that logit softcapping can be used by the fast kernels (see the sketch after this list).
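
For reference, logit softcapping bounds the raw attention scores with a tanh before the softmax; this div -> tanh -> mul chain is the structure the pattern matcher treats as an optional route. A minimal eager-mode sketch (the helper name is illustrative, not the PR's exact implementation; the logit_cap argument mirrors the one discussed in the review below):

import math

import torch
import torch.nn.functional as F

def sdpa_with_logit_softcap(query, key, value, scale=None, logit_cap=None):
    # Scaled dot-product attention with optional logit softcapping (sketch).
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if logit_cap is not None:
        # Softcapping: squash logits into (-logit_cap, logit_cap) so the
        # softmax never sees unbounded scores.
        scores = logit_cap * torch.tanh(scores / logit_cap)
    return torch.matmul(F.softmax(scores, dim=-1), value)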

Test Coverage

FlashInfer Output (note use_logits_cap_True in the JIT op names below, confirming the softcapping path is active in the kernel)

2025-06-24 22:08:52,873 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_256_head_dim_vo_256_posenc_0_use_swa_False_use_logits_cap_True_f16qk_False
2025-06-24 22:08:52,901 - INFO - flashinfer.jit: Finished loading JIT ops: page
2025-06-24 22:08:52,934 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_256_head_dim_vo_256_posenc_0_use_swa_False_use_logits_cap_True_f16qk_False
2025-06-24 22:09:12,674 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_256_head_dim_vo_256_posenc_0_use_swa_False_use_logits_cap_True_f16qk_False
2025-06-24 22:09:12,726 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_256_head_dim_vo_256_posenc_0_use_swa_False_use_logits_cap_True_f16qk_False
Processed requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:55<00:00, 27.74s/it]
[06/24/2025-22:09:16] [TRT-LLM AUTO-DEPLOY] [I] [PROMPT 0] How big is the universe? :

We don't know!

Here's why:

  • The universe is expanding: Space itself is stretching, meaning whatever objects are in space are moving farther apart.
  • We can only see so far: The light from the most distant galaxies hasn't had time to reach us yet. It's waiting to be discovered, adding to the mystery of the universe's size.
  • The observable universe: The observable universe is the area we

Triton Output

Processed requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:31<00:00, 15.82s/it]
[06/25/2025-13:28:59] [TRT-LLM AUTO-DEPLOY] [I] [PROMPT 0] How big is the universe? :

There's no easy answer to this, because we can't directly observe the edge of the universe. Here's what we do know:

  • The Observable Universe: This is the portion of the universe we can currently see due to the finite speed of light. It has a radius of about 93 billion light-years. That means if you sent a spaceship light-speed, it would take 93 billion years to reach it!

[06/25/2025-13:28:59] [TRT-LLM AUTO-DEPLOY] [I] [PROMPT 1] In simple words and in a single sentence, explain the concept of gravity: :

Gravity is the force that pulls objects towards each other.

You can also elaborate on that statement and explain different aspects:

Gravity's pull gets weaker the further away an object is from a massive object, like Earth. That's why the Moon orbits the Earth, and why a feather falls to the ground while a bowling ball stays put.

Here are some potential titles for a paragraph explaining gravity:

  • The Universal Force That Binds Us
  • The Ultimate "Sticky"

@nvchenghaoz changed the base branch from feat/ad_2025_06_13 to feat/ad-2025-06-24 on June 24, 2025 at 23:41
@nvchenghaoz (Author) commented:

This PR needs some refinement, as the output is weird; I will dig into the root cause tomorrow.

@nvchenghaoz force-pushed the chenghao/softcap-graph branch from 7164745 to f83e674 on June 25, 2025 at 20:32
Comment on lines 93 to 94:

  def scaled_dot_product_attention_fake(
-     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None
A collaborator commented:

@nvchenghaoz, how about we wait to merge this PR until we have decided on how we will support more attention features/arguments?
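
For context on the diff above: a fake (meta) kernel only propagates output shape and dtype for tracing/export, so the new logit_cap argument merely keeps the fake signature in sync with the real op. A hedged sketch of what such a fake kernel typically looks like, assuming SDPA's usual output shape; this is illustrative, not the repository's exact code:

import torch

def scaled_dot_product_attention_fake(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False,
    scale=None, logit_cap=None,
):
    # A fake kernel only describes the output's shape/dtype for the tracer;
    # logit_cap affects neither, so it is accepted and otherwise unused.
    return query.new_empty(query.shape[:-1] + (value.shape[-1],))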
