
Conversation


djsaunde commented Jun 14, 2025

What does this PR do? Please describe:
Adds flex attention as a default_sdpa option, plus a BlockMaskCache implementation to re-use block masks (which are expensive to compute) across layers and training steps.
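
For illustration, here is a minimal sketch of how such a cache might be keyed and populated; the actual BlockMaskCache in this PR may use a different key and eviction strategy:

from torch.nn.attention.flex_attention import BlockMask, create_block_mask

class BlockMaskCache:
    """Sketch: caches block masks keyed on the inputs that determine them, so a
    mask built for one layer / step can be re-used until the key changes."""

    def __init__(self) -> None:
        self._masks: dict[tuple, BlockMask] = {}

    def get_or_create(self, mask_mod, batch_size, seq_len, seq_lens_key, device) -> BlockMask:
        # seq_lens_key: a hashable summary (e.g. a tuple of ints) of whatever
        # per-batch data the mask_mod closes over.
        key = (batch_size, seq_len, seq_lens_key)
        mask = self._masks.get(key)
        if mask is None:
            mask = create_block_mask(mask_mod, batch_size, None, seq_len, seq_len, device=device)
            self._masks[key] = mask
        return mask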

Running:

fairseq2 lm instruction_finetune path/to/output --config model.name=llama3_2_1b common.torch.default_sdpa=flex

on the GSM8K dataset on a single H100 gives the TensorBoard visualizations below.

NLL comparison:

[image: NLL comparison]

Losses differ by at most ~0.0002 over the first 300 training steps. This is true with or without utilizing the block mask cache.

Throughput (elements per second):

[image: throughput comparison (elements per second)]

The plot compares TorchSDPA (gray), FlexSDPA with block mask caching (orange), and FlexSDPA without caching (green). TorchSDPA is fastest: FlexSDPA with caching runs at roughly 65% of its throughput, and without caching at only about 40% (!).

So we lose quite a bit of speed here, at least in the padded-batch setting. AFAICT, this is because flex attention, at least as of torch 2.6.0 (?), needs to recompile every time it runs with a batch-dependent mask function (see e.g. pytorch/pytorch#136196).
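
To make the issue concrete, here is an illustrative sketch (not code from this PR) of a padding-aware mask: the mask_mod closes over a per-batch seq_lens tensor, so the block mask has to be rebuilt whenever the padded-batch layout changes, which is where the recompilation cost shows up.

import torch
from torch.nn.attention.flex_attention import create_block_mask

def padded_causal_mask(seq_lens: torch.Tensor):
    # seq_lens holds the true (unpadded) length of each sequence in the batch.
    def mask_mod(b, h, q_idx, kv_idx):
        # Causal, and both positions must lie inside sequence b's real length.
        return (kv_idx <= q_idx) & (q_idx < seq_lens[b]) & (kv_idx < seq_lens[b])
    return mask_mod

# seq_lens differs from batch to batch, so nothing built from it can be cached
# across batches.
seq_lens = torch.tensor([37, 512, 204, 91], device="cuda")
block_mask = create_block_mask(
    padded_causal_mask(seq_lens), B=seq_lens.numel(), H=None, Q_LEN=512, KV_LEN=512, device="cuda"
)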

I expect the setting with packed sequences of a fixed length to be the most performant, since we only have to create a single block mask and can re-use it thereafter, instead of creating new ones whenever the sequence lengths change. In fact, torchtune seems to use flex attention only in the sample packing case for similar reasons.
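
For comparison, a sketch of the fixed-length packed case (ignoring document boundaries for simplicity): the mask depends only on the constant packed length, so one block mask can be built up front and re-used for every layer and step.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

PACKED_LEN = 8192  # fixed packed sequence length (illustrative value)

def causal(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx

# Built once; nothing in the mask depends on per-batch data.
packed_block_mask = create_block_mask(
    causal, B=None, H=None, Q_LEN=PACKED_LEN, KV_LEN=PACKED_LEN, device="cuda"
)

# Every attention layer / training step can then call:
#   out = flex_attention(q, k, v, block_mask=packed_block_mask)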

Losses also match on a handful of other tested configurations.

Not supported by this PR:

  • Dropout (apparently quite tricky to do efficiently)
  • ALiBi
  • Possibly others

Does your PR introduce any breaking changes? If yes, please list them:
N/A

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together? (note: sort of? the block mask caching makes this attention implementation more performant)
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

djsaunde self-assigned this Jun 14, 2025
djsaunde requested a review from cbalioglu as a code owner June 14, 2025 22:36
facebook-github-bot added the CLA Signed label Jun 14, 2025
djsaunde marked this pull request as draft June 14, 2025 23:11
djsaunde marked this pull request as ready for review June 15, 2025 01:15
djsaunde force-pushed the flex-attn-block-mask-cache branch from 01864e4 to e464d85 on June 17, 2025 13:53

meta-cla bot commented Aug 26, 2025

Hi @djsaunde!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

