Conversation

@Skylion007 Skylion007 commented Sep 5, 2025

Description

This forces the dtype of the attention mask passed to TransformerEngine to be uint8. This allows DeepSeekV3 and other models to run through TransformerEngine (using the nightly jax:maxtext container from 09-04-2025).

This is needed to pass the assert added in TransformerEngine at https://github.com/NVIDIA/TransformerEngine/blob/c47f329b2084406093124851a3aeecb935183def/transformer_engine/jax/cpp_extensions/softmax.py#L484, which otherwise rejects the mask dtype. With the cast in place, the MLA attention in DeepSeekV3 no longer trips that assert.
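For context, here is a minimal sketch of the kind of cast this change makes. The helper name, mask shape, and surrounding code are illustrative assumptions, not MaxText's actual implementation:

```python
import jax.numpy as jnp

def to_te_mask(mask):
    """Cast an attention mask to uint8 before handing it to the
    TransformerEngine fused-attention/softmax path, whose JAX primitive
    asserts on the mask dtype. The cast is value-preserving for 0/1 masks.
    Helper name and shapes here are hypothetical, for illustration only.
    """
    return mask.astype(jnp.uint8)

# Toy example: a batched lower-triangular 0/1 mask, cast for the TE kernel.
mask = jnp.tril(jnp.ones((1, 1, 8, 8), dtype=jnp.bool_))
te_mask = to_te_mask(mask)
assert te_mask.dtype == jnp.uint8
```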

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or commands to reproduce.

I ran it locally on H200s and verified the assert is no longer triggered when using cudnn_flash_te as the attention implementation.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

google-cla bot commented Sep 5, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Skylion007
Author

Signed the CLA

@Skylion007 Skylion007 changed the title Change attention mask type to uint8 for TE Change attention mask type to uint8 for cudnn_flash_te Sep 10, 2025