Commit 590af89

Merge branch 'main' into kv-cache-dce
2 parents: 5499de9 + 1b9781e

2,188 files changed: +37,210 / −9,718 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ tensorrt_llm/bindings/**/*.pyi
 tensorrt_llm/deep_ep/
 tensorrt_llm/deep_ep_cpp_tllm.*.so
 tensorrt_llm/deep_ep_cpp_tllm.pyi
+tensorrt_llm/deep_gemm/
+tensorrt_llm/deep_gemm_cpp_tllm.*.so
+tensorrt_llm/deep_gemm_cpp_tllm.pyi
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -26,3 +26,6 @@
 [submodule "3rdparty/cppzmq"]
 	path = 3rdparty/cppzmq
 	url = https://github.com/zeromq/cppzmq.git
+[submodule "3rdparty/DeepGEMM"]
+	path = 3rdparty/DeepGEMM
+	url = https://github.com/deepseek-ai/DeepGEMM.git

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ repos:
        args: [--allow-multiple-documents]
        exclude: ".*/gitlab/.*.yml"
      - id: trailing-whitespace
-       exclude: '\.patch$'
+       exclude: '\.(patch|md)$'
      - id: check-toml
      - id: mixed-line-ending
        args: [--fix=lf]

3rdparty/DeepGEMM

Submodule DeepGEMM added at 7b6b556

README.md

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ TensorRT-LLM
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
 [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
 [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-1.0.0rc6-green)](./tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-1.1.0rc0-green)](./tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

 [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
 ## Useful Links
 - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
 - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
-- [AutoDeploy](./examples/auto_deploy/README.md): An experimental backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
+- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
 - [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.

cpp/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
@@ -31,6 +31,7 @@ option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
 option(BUILD_TESTS "Build Google tests" ON)
 option(BUILD_BENCHMARKS "Build benchmarks" ON)
 option(BUILD_DEEP_EP "Build the Deep EP module" ON)
+option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON)
 option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
 option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
@@ -199,7 +200,9 @@ set(TRT_LIB TensorRT::NvInfer)
 get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)

 set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   add_subdirectory(${3RDPARTY_DIR}/pybind11
                    ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
 endif()
@@ -218,7 +221,9 @@ include_directories(
   ${3RDPARTY_DIR}/cutlass/tools/util/include
   ${3RDPARTY_DIR}/NVTX/include
   ${3RDPARTY_DIR}/json/include)
-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   include_directories(${3RDPARTY_DIR}/pybind11/include)
 endif()
 if(BINDING_TYPE STREQUAL "nanobind")

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 4 additions & 1 deletion
@@ -2027,7 +2027,7 @@ class GenericLlmRequest

         // Scatter the input tokens to other beam
         mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
-        mLastTokens = VecTokens(mSamplingConfig.beamWidth);
+        mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());

         // Init mUniqueTokens
         VecUniqueTokens uniqueTokens{inputTokens.size()};
@@ -2347,6 +2347,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
     void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);

     void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
+
+    // Remove LoRA weights and LoRA config tensors
+    void removeLoraTensors();
 };

 } // namespace tensorrt_llm::batch_manager
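
Note on the mLastTokens change: it relies on std::vector's fill constructor, which initializes every element to the given value instead of value-initializing to zero, so each beam now starts from the last prompt token. A minimal standalone sketch of that behaviour, using hypothetical toy values rather than the actual VecTokens type:

#include <cassert>
#include <vector>

int main()
{
    std::vector<int> inputTokens{11, 22, 33}; // stand-in for the prompt tokens
    int const beamWidth = 2;

    std::vector<int> zeroed(beamWidth);                     // old behaviour: {0, 0}
    std::vector<int> seeded(beamWidth, inputTokens.back()); // new behaviour: {33, 33}

    assert(zeroed == (std::vector<int>{0, 0}));
    assert(seeded == (std::vector<int>{33, 33}));
    return 0;
}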

cpp/include/tensorrt_llm/common/quantization.h

Lines changed: 57 additions & 16 deletions
@@ -122,6 +122,16 @@ class QuantMode
         return QuantMode(BaseType(1u) << 14);
     }

+    static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
+    {
+        return QuantMode(BaseType(1u) << 15);
+    }
+
+    static constexpr QuantMode w4a16Mxfp4() noexcept
+    {
+        return QuantMode(BaseType(1u) << 16);
+    }
+
     constexpr BaseType value() const noexcept
     {
         return mValue;
@@ -202,14 +212,25 @@ class QuantMode
         return isSet(w4a8Mxfp4Fp8());
     }

+    constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
+    {
+        return isSet(w4a8Mxfp4Mxfp8());
+    }
+
+    constexpr bool hasW4a16Mxfp4() const noexcept
+    {
+        return isSet(w4a16Mxfp4());
+    }
+
     constexpr bool hasKvCacheQuant() const noexcept
     {
        return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
     }

     static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
         bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
-        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
+        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
+        bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
     {
         QuantMode quantMode{};
         if (quantizeWeights)
@@ -278,25 +299,35 @@ class QuantMode
             quantMode += w4a8Mxfp4Fp8();
         }

+        if (useW4a8Mxfp4Mxfp8)
+        {
+            quantMode += w4a8Mxfp4Mxfp8();
+        }
+
+        if (useW4a16Mxfp4)
+        {
+            quantMode += w4a16Mxfp4();
+        }
+
         return quantMode;
     }

     static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
     {
-        return fromDescription(
-            true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
+        return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
+            false, false, false, false);
     }

     static constexpr QuantMode useQServe(bool perGroup)
     {
-        return fromDescription(
-            true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
+        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
+            false, false, false);
     }

     static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
     {
         return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
-            false, false, false);
+            false, false, false, false, false);
     }

     static QuantMode const fromQuantAlgo(
@@ -353,28 +384,38 @@ class QuantMode
         }
         else if (quantAlgo == "FP8")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, true, false, false, false, false, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
+                false, false, false, false, false);
         }
         else if (quantAlgo == "FP8_ROWWISE")
         {
-            quantMode = fromDescription(
-                false, false, true, true, false, false, false, false, false, true, false, false, false, false);
+            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
+                false, false, false, false);
         }
         else if (quantAlgo == "FP4")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, true, false, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                true, false, false, false, false);
         }
         else if (quantAlgo == "FP8_BLOCK_SCALES")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, false, true, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, true, false, false, false);
         }
         else if (quantAlgo == "W4A8_MXFP4_FP8")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, false, false, true);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, true, false, false);
+        }
+        else if (quantAlgo == "W4A8_MXFP4_MXFP8")
+        {
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, false, true, false);
+        }
+        else if (quantAlgo == "W4A16_MXFP4")
+        {
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, false, false, true);
         }

         if (kvCacheQuantAlgo == "INT8")
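
The two new modes follow the existing single-bit pattern (bits 15 and 16), so they can be combined with, and queried independently of, the other flags. A minimal usage sketch, assuming the class is reachable through the tensorrt_llm::common namespace suggested by the header path and that the factory methods are public like the existing ones; only names visible in this diff are used:

#include "tensorrt_llm/common/quantization.h"

using tensorrt_llm::common::QuantMode; // namespace assumed from the header path

int main()
{
    // Compose a mode from one of the new single-bit factories.
    QuantMode mode = QuantMode::w4a8Mxfp4Mxfp8();

    // Each scheme owns one bit, so the checks are independent mask tests.
    bool hasMxfp8Act = mode.hasW4a8Mxfp4Mxfp8(); // true: bit 15 is set
    bool hasW4a16 = mode.hasW4a16Mxfp4();        // false: bit 16 is not set

    return (hasMxfp8Act && !hasW4a16) ? 0 : 1;
}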

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ def getSMVersion():
     ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
 @pytest.mark.parametrize('flag', [
     "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
-    "-softcapping-scale-bmm1 30", "-contiguous-q-kv"
+    "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
 ])
 @pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
 def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
@@ -117,8 +117,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
         f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
         shell=True,
         check=True)
-    # alibi and softcapping-scale-bmm1 are mutually exclusive.
-    if '-softcapping-scale-bmm1' not in flag:
+    # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks.
+    if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
         subprocess.run(
             f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
             shell=True,

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h

Lines changed: 4 additions & 4 deletions
@@ -326,9 +326,6 @@ struct Compute
         uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]);
         Compute_tile_o ctile_o(0, smem_v);

-        // BMM2 epilogue
-        Tile_o_epilogue tile_o_epilogue(params);
-
         // Mutex between two compute groups.
         OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER);
         // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions).
@@ -368,6 +365,9 @@ struct Compute
             sage_scale_row = head_info.bidb * params.h + head_info.bidh;
         }

+        // BMM2 epilogue
+        Tile_o_epilogue tile_o_epilogue(params, head_info);
+
         int q_step_idx = warpgroup_id;

         // Compute work.
@@ -490,7 +490,7 @@ struct Compute
         if (valid_run)
         {
             // Final step's update.
-            tile_o_epilogue.scale(ctile_o, p_sum);
+            tile_o_epilogue.scale(ctile_o, p_max, p_sum);
             // Store o_tile to gmem.
             gmem_o.store(ctile_o.acc_);
         }
