Remove device to host sync triggered in _flash_attention_forward #39213

@piyifan123

Description

Feature request

Problem

In `_flash_attention_forward`, the condition

`max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())`

contains the check `(torch.diff(position_ids, dim=-1) >= 0).all()`, which forces the result computed from the device tensor `position_ids` to be synced to the host side.

During inference and training, this can cause serious performance degradation because the CPU blocks on the sync. See the following profile for an example:

[profiler trace screenshot illustrating the blocking device-to-host sync]
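To make the problem concrete, here is a minimal sketch (assuming PyTorch is installed) of the check in question. `(...).all()` returns a 0-dim device tensor; evaluating it inside a Python `if`/`not` forces a device-to-host copy, which blocks the CPU until the GPU catches up. On CPU tensors the logic is the same, just without the sync cost:

```python
import torch

def positions_are_monotonic(position_ids: torch.Tensor) -> torch.Tensor:
    # Non-decreasing positions along the sequence dimension.
    # On a CUDA tensor, converting this result to a Python bool
    # is the device-to-host sync point described above.
    return (torch.diff(position_ids, dim=-1) >= 0).all()

# Regular left-to-right positions: monotonic.
regular = torch.arange(8).unsqueeze(0)
# Packed sequences restart at 0 mid-row: not monotonic.
packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])

print(bool(positions_are_monotonic(regular)))  # True
print(bool(positions_are_monotonic(packed)))   # False
```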

Proposal

Precompute the result of `(torch.diff(position_ids, dim=-1) >= 0).all()` and store it in `FlashAttentionKwargs`, so that we no longer perform this device-to-host sync in every attention call in every layer.

The only open question is whether there exists a model for which this cannot be precomputed, i.e., a model whose `position_ids` sequence changes during the forward pass for the same batch. Given that we already cache `cu_seqlens` in `FlashAttentionKwargs` (which carries equivalent information to `position_ids`), it seems safe to assume there isn't.
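The proposal could be sketched roughly as follows. This is an illustrative sketch only; `prepare_fa_kwargs`, `needs_varlen_path`, and the `is_monotonic` key are hypothetical names, not the actual transformers API:

```python
import torch

def prepare_fa_kwargs(position_ids: torch.Tensor) -> dict:
    # Pay the device-to-host sync once, before the layer loop,
    # and cache the result as a plain Python bool.
    return {
        "is_monotonic": bool((torch.diff(position_ids, dim=-1) >= 0).all()),
    }

def needs_varlen_path(query_length: int, max_length_q, fa_kwargs: dict) -> bool:
    # Each attention call now reads the cached bool instead of
    # re-running the device tensor check and blocking per layer.
    return max_length_q is not None or (
        query_length != 1 and not fa_kwargs["is_monotonic"]
    )

# Packed positions restart mid-row, so the varlen path is selected.
fa_kwargs = prepare_fa_kwargs(torch.tensor([[0, 1, 2, 0, 1]]))
print(needs_varlen_path(query_length=5, max_length_q=None, fa_kwargs=fa_kwargs))  # True
```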

Motivation

As stated above, this can severely degrade the out-of-the-box performance of transformers, and it is usually hard for a typical user to notice. The fix follows an existing mechanism, the `FlashAttentionKwargs` approach, for avoiding recomputation of FA-related kwargs.

Your contribution

I can prepare a PR if the team agrees with the proposed approach.
