
Switch to sdpa_kernel api with newer torch version #34411

@mobicham

System Info

  • transformers version: 4.45.2
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

I noticed that many files in transformers still use the older SDP API, torch.backends.cuda.sdp_kernel. We just discovered a bug in PyTorch 2.5.0 that makes code using the old SDP API run slower: pytorch/pytorch#138386

It would be a good idea to update to the new API (from torch.nn.attention import sdpa_kernel, SDPBackend) and set the appropriate compile flag, to avoid losing as much as 20% of the performance!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Gist here as a reference example: https://gist.github.com/mobicham/aa1e77689d9cf866cbea2cb75a53a9e4
More details in the torch issue: pytorch/pytorch#138386

Expected behavior

Examples using SDP with torch 2.5.0 should run at least as fast as with torch 2.4.1.

Labels: Documentation, Good Second Issue, bug
