Flex attention integration with block mask cache #1210
What does this PR do? Please describe:
Adding flex attention as a `default_sdpa` option, plus a `BlockMaskCache` implementation, in order to re-use block masks (which are expensive to compute) across layers / training steps.
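As a rough illustration of the idea (not the PR's actual code: the `get_or_create` method, the cache key, and the causal-with-padding mask function are all assumptions here), a minimal cache built on `torch.nn.attention.flex_attention.create_block_mask` could look like this:

```python
from typing import Dict, Tuple

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


class BlockMaskCache:
    """Caches block masks so they can be shared across layers and steps."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[int, ...], BlockMask] = {}

    def get_or_create(self, seq_lens: torch.Tensor, max_seq_len: int) -> BlockMask:
        # Key on the batch's (padded) sequence lengths: batches with the same
        # lengths can reuse the same (expensive-to-build) block mask.
        key = (max_seq_len, *seq_lens.tolist())
        mask = self._cache.get(key)
        if mask is None:
            def causal_padded(b, h, q_idx, kv_idx):
                # Causal attention restricted to non-padding key positions.
                # Closing over `seq_lens` makes the mask batch-dependent,
                # which is what triggers the recompiles discussed below.
                return (q_idx >= kv_idx) & (kv_idx < seq_lens[b])

            mask = create_block_mask(
                causal_padded,
                B=seq_lens.numel(),
                H=None,  # broadcast over attention heads
                Q_LEN=max_seq_len,
                KV_LEN=max_seq_len,
                device=seq_lens.device,
            )
            self._cache[key] = mask
        return mask
```

A single cache instance shared by all attention layers would then pay the cost of `create_block_mask` once per distinct batch shape rather than once per layer per step.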
Running on the gsm8k dataset on a single H100 gives the tensorboard visualizations below.
NLL comparison:
Losses differ by at most ~0.0002 over the first 300 training steps. This is true with or without utilizing the block mask cache.
Throughput (elements per second):
`TorchSDPA` (gray) vs. `FlexSDPA` with block mask caching (orange) vs. `FlexSDPA` without (green) are compared above. `TorchSDPA` is fastest; `FlexSDPA` with caching runs at about 65% of its throughput, and without caching at about 40% (!). So we lose quite a bit of speed here, at least in the padded-batches setting. AFAICT, this is because flex attention, at least as of torch 2.6.0 (?), needs to recompile every time it runs with a batch-dependent mask function (see e.g. pytorch/pytorch#136196).
I expect the setting with packed sequences of a fixed length to be the most performant, since we only have to create a single block mask and can re-use it thereafter, instead of creating new ones whenever the sequence lengths change. In fact, torchtune seems to use flex attention only in the sample packing case for similar reasons.
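For contrast, here is a sketch of that fixed-length case (the shapes and the purely causal mask are illustrative; a real packed-sequence mask would also incorporate document boundaries): because the mask does not depend on batch contents, a single `BlockMask` can be built once and passed to every compiled `flex_attention` call without forcing recompilation.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_attention = torch.compile(flex_attention)

B, H, S, D = 8, 16, 4096, 64  # illustrative shapes

def causal(b, h, q_idx, kv_idx):
    # No dependence on batch contents, so the mask never changes.
    return q_idx >= kv_idx

# Built once; reused for every layer and every training step.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3)
)
out = flex_attention(q, k, v, block_mask=block_mask)
```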
Losses also match on a handful of other tested configurations.
Not supported by this PR:
Does your PR introduce any breaking changes? If yes, please list them:
N/A
Check list: