
Commit 14b36e0

[TRTLLM-6174][feat] Enable FP32 mamba ssm cache (#6574)
Signed-off-by: Shahar Mor <[email protected]>
1 parent 199f306 commit 14b36e0

16 files changed (+148, −41 lines)


tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 3 additions & 0 deletions
@@ -147,6 +147,8 @@ def __init__(
             quant_config=config.get_quant_config(),
             allreduce_strategy=config.allreduce_strategy)
 
+        self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -230,6 +232,7 @@ def forward(
                 seq_idx=seq_idx,
                 return_varlen_states=True,
                 return_final_states=False,
+                mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype,
             )
             out.append(rearrange(y, "b l h p -> (b l) (h p)"))
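
Why let the cache dtype differ from the activation dtype at all: the SSM recurrent state is accumulated across many decode steps, and low-precision accumulation drifts. A minimal, self-contained sketch of the effect (illustrative only, not code from this commit; the decay factor and shapes are made up):

# Illustrative only: compares recurrent-state accumulation in BF16 vs FP32.
import torch

torch.manual_seed(0)
steps, dstate = 2048, 64
decay = 0.999
updates = 0.01 * torch.randn(steps, dstate)

state_fp32 = torch.zeros(dstate)
state_bf16 = torch.zeros(dstate, dtype=torch.bfloat16)
for u in updates:
    state_fp32 = decay * state_fp32 + u
    state_bf16 = decay * state_bf16 + u.to(torch.bfloat16)

drift = (state_bf16.float() - state_fp32).abs().max().item()
print(f"max |bf16 state - fp32 state| after {steps} steps: {drift:.3e}")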

tensorrt_llm/_torch/modules/mamba/ssd_combined.py

Lines changed: 27 additions & 20 deletions
@@ -16,6 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import torch
 from einops import rearrange
 
@@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
         cu_seqlens=None,
         dt_softplus=False,
         dt_limit=(0.0, float("inf")),
+        mamba_ssm_cache_dtype=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
     _, _, ngroups, dstate = B.shape
@@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd(
                           if initial_states is not None else None),
         seq_idx=seq_idx,
         chunk_size=chunk_size,
-        out_dtype=C.dtype,
+        out_dtype=mamba_ssm_cache_dtype or C.dtype,
         is_cont_batched=cu_seqlens is not None)
     states, final_states = [
         rearrange(t, "... (p n) -> ... p n", n=dstate)
@@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd(
     return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
 
 
-def mamba_chunk_scan_combined(x,
-                              dt,
-                              A,
-                              B,
-                              C,
-                              chunk_size,
-                              D=None,
-                              z=None,
-                              dt_bias=None,
-                              initial_states=None,
-                              seq_idx=None,
-                              chunk_indices=None,
-                              chunk_offsets=None,
-                              cu_seqlens=None,
-                              dt_softplus=False,
-                              dt_limit=(0.0, float("inf")),
-                              return_final_states=False,
-                              return_varlen_states=False):
+def mamba_chunk_scan_combined(
+        x,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        D=None,
+        z=None,
+        dt_bias=None,
+        initial_states=None,
+        seq_idx=None,
+        chunk_indices=None,
+        chunk_offsets=None,
+        cu_seqlens=None,
+        dt_softplus=False,
+        dt_limit=(0.0, float("inf")),
+        return_final_states=False,
+        return_varlen_states=False,
+        mamba_ssm_cache_dtype: Optional[torch.dtype] = None):
     """
     Argument:
         x: (batch, seqlen, nheads, headdim)
@@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x,
         seq_idx: (batch, seqlen)
         cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
         dt_softplus: Whether to apply softplus to dt
+        mamba_ssm_cache_dtype: torch.dtype, default to None
     Return:
         out: (batch, seqlen, nheads, headdim)
     """
@@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x,
         chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
-        dt_limit=dt_limit)
+        dt_limit=dt_limit,
+        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
     if not return_varlen_states:
         return out if not return_final_states else (out, final_states)
     else:
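
A usage sketch of the new keyword: with mamba_ssm_cache_dtype=torch.float32, the returned final states (the tensors that feed the SSM cache) come back in FP32 even when the activations are BF16. The shapes of x, B and C follow the docstring above; the dt and A shapes are assumptions based on the common Mamba-2 interface, and the call needs a CUDA device with the Triton kernels available, so treat this as a sketch rather than a verified snippet:

# Sketch only: assumes CUDA + Triton kernels; tensor contents are random and not meaningful.
import torch
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

batch, seqlen, nheads, headdim = 1, 256, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
dev, act_dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=act_dtype)
dt = torch.rand(batch, seqlen, nheads, device=dev, dtype=act_dtype)   # assumed shape
A = -torch.rand(nheads, device=dev, dtype=torch.float32)              # assumed shape
B = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=act_dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=act_dtype)

out, final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    dt_softplus=True,
    return_final_states=True,
    mamba_ssm_cache_dtype=torch.float32,  # new in this commit
)
print(out.dtype, final_states.dtype)  # expected: BF16 activations, FP32 states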

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 0 deletions
@@ -330,6 +330,7 @@ def _create_kv_cache_manager(
         mamba_layer_mask = [
             char == "M" for char in config.hybrid_override_pattern
         ]
+
         kv_cache_manager = MambaHybridCacheManager(
             # mamba cache parameters
             config.ssm_state_size,
@@ -340,6 +341,8 @@ def _create_kv_cache_manager(
             mamba_num_layers,
             mamba_layer_mask,
             config.torch_dtype,
+            model_engine.model.model_config.quant_config.
+            mamba_ssm_cache_dtype,
             # kv cache parameters
             executor_config.kv_cache_config,
             tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ class PyTorchConfig:
     """
 
     kv_cache_dtype: str = "auto"
+    mamba_ssm_cache_dtype: str = "auto"
+
     enable_iter_perf_stats: bool = False
     # If true, enables per request stats per iteration
     # Must also set enable_iter_perf_stats to true to get request stats
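
For reference, a minimal sketch of opting into an FP32 SSM cache through this config (assuming PyTorchConfig is constructed with keyword arguments, as the dataclass fields above suggest; all other fields keep their defaults):

# Sketch: hold Mamba SSM state in FP32 while leaving the attention KV cache on "auto".
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

pyt_config = PyTorchConfig(
    kv_cache_dtype="auto",
    mamba_ssm_cache_dtype="float32",  # new field added in this commit
)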

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 15 additions & 1 deletion
@@ -22,7 +22,8 @@
     get_num_extra_kv_tokens, update_spec_config_from_model_config)
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
-                                 torch_dtype_to_str, trace_func)
+                                 str_dtype_to_torch, torch_dtype_to_str,
+                                 trace_func)
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
@@ -98,6 +99,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 _VALID_KV_CACHE_DTYPES = ("fp8", "auto")
 
 
+def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
+                                           mamba_ssm_cache_dtype: str) -> None:
+    if mamba_ssm_cache_dtype == "auto":
+        mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype
+    else:
+        mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype)
+
+    config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
+
+
 def validate_and_set_kv_cache_quant(model_config: ModelConfig,
                                     pyt_kv_cache_dtype: str) -> QuantAlgo:
     logger.info(
@@ -1022,6 +1033,9 @@ def _load_model(self,
 
         validate_and_set_kv_cache_quant(
             config, self.pytorch_backend_config.kv_cache_dtype)
+        validate_and_set_mamba_ssm_cache_dtype(
+            config, self.pytorch_backend_config.mamba_ssm_cache_dtype)
+
         num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
         if num_layers > 0:
             config.pretrained_config.num_hidden_layers = num_layers
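
The resolution rule mirrors the KV-cache path: "auto" falls back to the checkpoint's torch_dtype, anything else is parsed with str_dtype_to_torch. A standalone sketch of the same mapping (the helper name below is hypothetical and not part of the commit):

# Hypothetical standalone version of the rule in validate_and_set_mamba_ssm_cache_dtype.
import torch
from tensorrt_llm._utils import str_dtype_to_torch

def resolve_mamba_ssm_cache_dtype(requested: str, model_dtype: torch.dtype) -> torch.dtype:
    if requested == "auto":
        return model_dtype                # follow the checkpoint dtype
    return str_dtype_to_torch(requested)  # e.g. "float32" -> torch.float32

assert resolve_mamba_ssm_cache_dtype("auto", torch.bfloat16) == torch.bfloat16
assert resolve_mamba_ssm_cache_dtype("float32", torch.bfloat16) == torch.float32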

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 10 additions & 1 deletion
@@ -939,9 +939,12 @@ def __init__(
         max_batch_size: int,
         mapping: Mapping,
         dtype: torch.dtype,
+        ssm_cache_dtype: torch.dtype,
         layer_mask: Optional[List[bool]] = None,
     ) -> None:
 
+        self.mamba_ssm_cache_dtype = ssm_cache_dtype
+
         # get tp size
         tp_size = mapping.tp_size
 
@@ -993,7 +996,7 @@ def __init__(
                 head_dim,
                 d_state,
             ],
-            dtype=dtype,
+            dtype=self.mamba_ssm_cache_dtype,
             device=device,
         )
 
@@ -1051,6 +1054,9 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
         layer_offset = self.mamba_layer_offsets[layer_idx]
         return self.ssm_states[layer_offset]
 
+    def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
+        return self.mamba_ssm_cache_dtype
+
     def shutdown(self):
         # release tensor memory, keeping python references as tensors
         self.conv_states = torch.tensor([])
@@ -1072,6 +1078,8 @@ def __init__(
         mamba_num_layers: int,
         mamba_layer_mask: List[bool],
         mamba_cache_dtype: torch.dtype,
+        mamba_ssm_cache_dtype: torch.dtype,
+
         # kv cache parameters
         kv_cache_config: KvCacheConfigCpp,
         kv_cache_type: CacheTypeCpp,
@@ -1105,6 +1113,7 @@ def __init__(
             max_batch_size,
             mapping,
             mamba_cache_dtype,
+            mamba_ssm_cache_dtype,
             mamba_layer_mask,
         )
 
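
The practical cost: each Mamba layer's SSM state is now allocated directly in the configured dtype, so FP32 doubles that part of the cache relative to BF16. A back-of-the-envelope sketch, assuming a per-layer state of roughly (max_batch_size, nheads, head_dim, d_state) (only the trailing head_dim, d_state dimensions are visible in the diff) and made-up model dimensions:

# Rough sizing only; dimensions are illustrative, not taken from a real model config.
import torch

def ssm_cache_bytes(num_mamba_layers, max_batch_size, nheads, head_dim, d_state, dtype):
    elems = num_mamba_layers * max_batch_size * nheads * head_dim * d_state
    return elems * torch.empty((), dtype=dtype).element_size()

for dtype in (torch.bfloat16, torch.float32):
    size = ssm_cache_bytes(num_mamba_layers=28, max_batch_size=64,
                           nheads=128, head_dim=64, d_state=128, dtype=dtype)
    print(f"{dtype}: {size / 2**30:.2f} GiB of SSM state")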

tensorrt_llm/bench/benchmark/low_latency.py

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,12 @@
     default=.90,
     help="The percentage of memory to use for KV Cache after model load.",
 )
+@optgroup.option(
+    "--mamba_ssm_cache_dtype",
+    type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
+    default="auto",
+    help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
+)
 @optgroup.option(
     "--max_seq_len",
     type=int,

tensorrt_llm/bench/benchmark/throughput.py

Lines changed: 6 additions & 0 deletions
@@ -84,6 +84,12 @@
     default=.90,
     help="The percentage of memory to use for KV Cache after model load.",
 )
+@optgroup.option(
+    "--mamba_ssm_cache_dtype",
+    type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
+    default="auto",
+    help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
+)
 @optgroup.group(
     "Engine Input Configuration",
     help="Input configuration for driving the engine.",

tensorrt_llm/bench/benchmark/utils/general.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,7 @@
     validate_and_set_kv_cache_quant
 from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings,
                                             get_model_config)
+from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig
 from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
                                                     InferenceRequest)
 from tensorrt_llm.logger import logger
@@ -88,6 +89,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     enable_chunked_prefill = params.get("enable_chunked_prefill", False)
 
     kv_cache_dtype = "auto"
+    mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto")
     kv_cache_config = {}
     if extra_llm_api_options:
         with open(extra_llm_api_options, 'r') as f:
@@ -96,6 +98,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
                 "dtype": "auto",
             })
             kv_cache_dtype = kv_cache_config.get("dtype", "auto")
+            mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype",
+                                                        mamba_ssm_cache_dtype)
 
             enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                                        enable_chunked_prefill)
@@ -115,6 +119,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     else:
         model_config = get_model_config(model, model_path)
 
+    if isinstance(model_config, NemotronHybridConfig):
+        model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype)
+
     from tensorrt_llm._torch.model_config import ModelConfig
     model = model_path or model
     tllm_model_config = ModelConfig.from_pretrained(model,
@@ -161,6 +168,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     }
 
     kv_cache_config["dtype"] = kv_cache_dtype
+    kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype
 
     pyt_options = {
         "cuda_graph_config": cuda_graph_config,

tensorrt_llm/bench/build/dataclasses.py

Lines changed: 4 additions & 0 deletions
@@ -223,6 +223,7 @@ class NemotronHybridConfig(ModelConfig):
     mamba_head_dim: int
     d_inner: Optional[int] = Field(default=None)
     num_mamba_layers: Optional[int] = Field(default=None)
+    mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")
 
     @model_validator(mode="after")
     def set_values_if_none(self):
@@ -248,3 +249,6 @@ def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
     def cache_memory_fraction(self, cache_memory_fraction):
         # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
         return cache_memory_fraction**2
+
+    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
+        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
