Feature request
Problem
In the flash attention code path, the check

```python
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
```

causes the result computed from the device tensor position_ids to be synced to the host side.
During inference/training, this can cause serious performance degradation due to CPU blocking; see the following for an example:
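A minimal sketch of where the blocking happens (the helper name is hypothetical; the point is that the implicit `bool()` conversion of the 0-dim result tensor is the sync):

```python
import torch

def positions_are_monotonic(position_ids: torch.Tensor) -> bool:
    # torch.diff and .all() run on the tensor's device and return a
    # 0-dim tensor; converting that tensor to a Python bool is what
    # forces a device-to-host sync when position_ids lives on a GPU.
    return bool((torch.diff(position_ids, dim=-1) >= 0).all())

# Contiguous positions: all diffs are 1, so the check passes.
contiguous = torch.arange(8).unsqueeze(0)
# Packed sequences restart their position ids, so one diff is negative.
packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]])
```

On CPU this conversion is cheap, but on CUDA it blocks the host thread until all queued kernels producing `position_ids` have finished, once per attention layer.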

Proposal
Precompute the result of `(torch.diff(position_ids, dim=-1) >= 0).all()` and store it in the FlashAttentionKwargs, so that we don't have to perform this device-to-host sync in every attention call in every layer.
The only question is whether there exists a model for which this cannot be precomputed, i.e., whether the position_ids sequence can change during the forward pass for the same batch. Given that we already cache cu_seqlens in FlashAttentionKwargs (which carries equivalent information to position_ids), it should be safe to assume it cannot.
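The proposal above can be sketched as follows. This is a hypothetical illustration, not the existing transformers API: the kwarg name `position_ids_monotonic` and both helper functions are invented here, and the real change would live alongside the existing FlashAttentionKwargs / cu_seqlens preparation.

```python
import torch

def prepare_fa_kwargs(position_ids: torch.Tensor) -> dict:
    # Hypothetical sketch: pay the device-to-host sync exactly once per
    # batch, next to where cu_seqlens is already computed and cached.
    monotonic = bool((torch.diff(position_ids, dim=-1) >= 0).all())
    return {"position_ids_monotonic": monotonic}

def needs_varlen_path(query_length: int, max_length_q, fa_kwargs: dict) -> bool:
    # The original per-layer condition, now reading the precomputed
    # Python bool instead of re-evaluating the tensor expression,
    # so no per-layer sync occurs.
    return max_length_q is not None or (
        query_length != 1 and not fa_kwargs["position_ids_monotonic"]
    )
```

Each layer's check then becomes a pure host-side branch on a cached bool, which is exactly the pattern FlashAttentionKwargs already uses for cu_seqlens.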
Motivation
As stated above, this can severely degrade the out-of-the-box performance of transformers, and it is usually hard for a normal user to notice. The fix would follow an existing mechanism, i.e., the FlashAttentionKwargs approach of avoiding recomputation of FA-required kwargs.
Your contribution
I can prepare a PR if the team thinks the proposed approach is OK.