Skip to content

Commit 7fda79a

Browse files
authored
Speed up sdpa_mask for MindSpore (#2112)
1 parent 0e89470 commit 7fda79a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mindnlp/transformers/masking_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ def sdpa_mask_older_torch(
283283
# as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
284284
# However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
285285
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
286-
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
286+
causal_mask = mask_function(None, None, cache_position.reshape(cache_position.shape[0], 1), kv_arange.reshape(1, kv_arange.shape[0]))
287+
# causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
287288
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
288289
if padding_mask is not None:
289290
causal_mask = causal_mask * padding_mask[:, None, None, :]

0 commit comments

Comments
 (0)