16
16
# See the License for the specific language governing permissions and
17
17
# limitations under the License.
18
18
19
+ from typing import Optional
20
+
19
21
import torch
20
22
from einops import rearrange
21
23
@@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
43
45
cu_seqlens = None ,
44
46
dt_softplus = False ,
45
47
dt_limit = (0.0 , float ("inf" )),
48
+ mamba_ssm_cache_dtype = None ,
46
49
):
47
50
batch , seqlen , nheads , headdim = x .shape
48
51
_ , _ , ngroups , dstate = B .shape
@@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd(
120
123
if initial_states is not None else None ),
121
124
seq_idx = seq_idx ,
122
125
chunk_size = chunk_size ,
123
- out_dtype = C .dtype ,
126
+ out_dtype = mamba_ssm_cache_dtype or C .dtype ,
124
127
is_cont_batched = cu_seqlens is not None )
125
128
states , final_states = [
126
129
rearrange (t , "... (p n) -> ... p n" , n = dstate )
@@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd(
174
177
return out , out_x , dt , dA_cumsum , states , final_states , varlen_states
175
178
176
179
177
- def mamba_chunk_scan_combined (x ,
178
- dt ,
179
- A ,
180
- B ,
181
- C ,
182
- chunk_size ,
183
- D = None ,
184
- z = None ,
185
- dt_bias = None ,
186
- initial_states = None ,
187
- seq_idx = None ,
188
- chunk_indices = None ,
189
- chunk_offsets = None ,
190
- cu_seqlens = None ,
191
- dt_softplus = False ,
192
- dt_limit = (0.0 , float ("inf" )),
193
- return_final_states = False ,
194
- return_varlen_states = False ):
180
+ def mamba_chunk_scan_combined (
181
+ x ,
182
+ dt ,
183
+ A ,
184
+ B ,
185
+ C ,
186
+ chunk_size ,
187
+ D = None ,
188
+ z = None ,
189
+ dt_bias = None ,
190
+ initial_states = None ,
191
+ seq_idx = None ,
192
+ chunk_indices = None ,
193
+ chunk_offsets = None ,
194
+ cu_seqlens = None ,
195
+ dt_softplus = False ,
196
+ dt_limit = (0.0 , float ("inf" )),
197
+ return_final_states = False ,
198
+ return_varlen_states = False ,
199
+ mamba_ssm_cache_dtype : Optional [torch .dtype ] = None ):
195
200
"""
196
201
Argument:
197
202
x: (batch, seqlen, nheads, headdim)
@@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x,
207
212
seq_idx: (batch, seqlen)
208
213
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
209
214
dt_softplus: Whether to apply softplus to dt
215
+ mamba_ssm_cache_dtype: torch.dtype or None; used as the out_dtype when computing the states. Defaults to None, in which case C.dtype is used
210
216
Return:
211
217
out: (batch, seqlen, nheads, headdim)
212
218
"""
@@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x,
231
237
chunk_offsets = chunk_offsets ,
232
238
cu_seqlens = cu_seqlens ,
233
239
dt_softplus = dt_softplus ,
234
- dt_limit = dt_limit )
240
+ dt_limit = dt_limit ,
241
+ mamba_ssm_cache_dtype = mamba_ssm_cache_dtype )
235
242
if not return_varlen_states :
236
243
return out if not return_final_states else (out , final_states )
237
244
else :
0 commit comments