
integrate mxfp8 dim1 cast kernel choice enum into MXLinear #2554

Closed

Conversation

danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Jul 15, 2025

Stacked PRs:


integrate mxfp8 dim1 cast kernel choice enum into MXLinear

Summary

  • Add the MXFP8Dim1CastKernelChoice enum and replace all uses of the boolean flag use_fp8_dim1_cast_triton_kernel with it. (Defaults to Triton for now.)
  • Update tests accordingly and verify they are passing.
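
A minimal sketch of the shape of this change, not the code from the diff. The member names and string values are assumptions inferred from the choices named in this PR (triton, cuda, and none), `MXLinearConfigSketch` is a hypothetical stand-in for the real config class, and `from_legacy_flag` is a hypothetical helper shown only to illustrate how the old boolean could map onto the new option.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class MXFP8Dim1CastKernelChoice(Enum):
    # Member names and values are assumed for illustration.
    TRITON = "triton"
    CUDA = "cuda"


@dataclass
class MXLinearConfigSketch:
    # Hypothetical stand-in for the real config class.
    # Before: use_fp8_dim1_cast_triton_kernel: bool = False
    # After: an enum-valued option that defaults to Triton.
    # None is assumed to mean "no custom dim1 cast kernel".
    mxfp8_dim1_cast_kernel_choice: Optional[MXFP8Dim1CastKernelChoice] = (
        MXFP8Dim1CastKernelChoice.TRITON
    )


def from_legacy_flag(use_triton: bool) -> Optional[MXFP8Dim1CastKernelChoice]:
    """Hypothetical helper: map the old boolean onto the new option."""
    return MXFP8Dim1CastKernelChoice.TRITON if use_triton else None
```

An enum leaves room for more kernels later, which a boolean could not express without adding a second flag.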

Test plan

  • pytest test/prototype/mx_formats/test_mx_linear.py -k eager_vs_hp
  • pytest test/prototype/mx_formats/test_mx_linear.py -k compile

Next steps

  • Integrate into torchtitan for e2e FSDP training tests once this stack lands. Torchtitan PR: [mxpf8] Make mxfp8 dim1 cast kernel configurable torchtitan#1401
  • DTensor tests still fail with both the Triton and CUDA kernels: ./test/prototype/mx_formats/test_mx_dtensor.sh
    • Triton error (known issue): RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
    • CUDA error in DTensor op dispatch: assert res.ndim == 0, "output tensor should be scalar!"


pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2554

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 5 Pending

As of commit b7c508c with merge base 95d13d5 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/9 branch from d858130 to b01a184 Compare July 15, 2025 20:31
danielvegamyhre added a commit that referenced this pull request Jul 15, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 97e9a7b to 380e887 Compare July 15, 2025 20:31
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 15, 2025
@danielvegamyhre danielvegamyhre added mx topic: not user facing Use this tag if you don't want this PR to show up in release notes labels Jul 15, 2025
@danielvegamyhre
Contributor Author

cc @vkuzo for review of this stack. Changes are tested via the tests run in the "Test plan" section, which have been updated to cover all 3 dim1 cast kernel choices (none, triton, cuda).
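
One way to enumerate the three cases mentioned above, as a sketch rather than the actual test code. The enum members are assumed, and `all_dim1_cast_choices` and `case_id` are hypothetical helper names.

```python
from enum import Enum
from typing import List, Optional


class MXFP8Dim1CastKernelChoice(Enum):
    # Member names and values are assumed for illustration.
    TRITON = "triton"
    CUDA = "cuda"


def all_dim1_cast_choices() -> List[Optional[MXFP8Dim1CastKernelChoice]]:
    """Return None (no custom kernel) followed by every enum member."""
    return [None, *MXFP8Dim1CastKernelChoice]


def case_id(choice: Optional[MXFP8Dim1CastKernelChoice]) -> str:
    """Readable label for a test case, e.g. for use as a test id."""
    return "none" if choice is None else choice.value
```

A list like this could be handed to `pytest.mark.parametrize`, with `case_id` as the `ids` callable, so that each kernel choice shows up as its own test case.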

@danielvegamyhre danielvegamyhre requested review from vkuzo and drisspg July 15, 2025 20:41
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 15, 2025 23:43
danielvegamyhre added a commit that referenced this pull request Jul 15, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 380e887 to 6311b94 Compare July 15, 2025 23:43
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 15, 2025 23:43
@@ -33,6 +33,11 @@ class MXGemmKernelChoice(Enum):
    CUBLAS = "cublas"


class MXFP8Dim1CastKernelChoice(Enum):
Contributor

nit: name it to support dim0_dim1 cast if that is added in the future? MXFP8CastKernelChoice?

# TODO(1945): remove this config option once torch.compile gives us
# a fast kernel
use_fp8_dim1_cast_triton_kernel: bool = False
mxfp8_dim1_cast_kernel_choice: Optional[MXFP8Dim1CastKernelChoice] = (
Contributor

How would someone use the torch.compile-generated kernel?

Contributor Author

By setting mxfp8_dim1_cast_kernel_choice=None. Perhaps we should make it more explicit by adding a third enum option, MXFP8CastKernelChoice.TORCH?
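
A sketch of what the more explicit option could look like, assuming the rename suggested in review and a third TORCH member. None of this is confirmed by the diff: the member values and the `describe_cast_path` helper are hypothetical and only illustrate how an explicit member removes the need to interpret None.

```python
from enum import Enum


class MXFP8CastKernelChoice(Enum):
    # Name follows the reviewer's suggestion; members and values are assumed.
    TRITON = "triton"
    CUDA = "cuda"
    TORCH = "torch"  # explicit replacement for passing None


def describe_cast_path(choice: MXFP8CastKernelChoice) -> str:
    """Hypothetical helper: name the code path each choice would select."""
    if choice is MXFP8CastKernelChoice.TORCH:
        return "torch.compile-generated kernel"
    if choice is MXFP8CastKernelChoice.TRITON:
        return "custom Triton kernel"
    if choice is MXFP8CastKernelChoice.CUDA:
        return "custom CUDA kernel"
    raise ValueError(f"unsupported kernel choice: {choice!r}")
```

With every path named by a member, the option no longer needs to be Optional, and a reader of the config can see all supported choices in one place.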

@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 15:25
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 6311b94 to 8a86dae Compare July 16, 2025 15:26
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 15:26
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 15:33
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 8a86dae to ab8e821 Compare July 16, 2025 15:33
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 15:33
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 16:14
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from ab8e821 to 7a48119 Compare July 16, 2025 16:14
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 16:15
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 16:18
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 7a48119 to 622b26d Compare July 16, 2025 16:18
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 16:19
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 16:22
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 622b26d to 3efc354 Compare July 16, 2025 16:22
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 16:22
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 17:19
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 3efc354 to 80fc9d0 Compare July 16, 2025 17:19
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 17:19
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/9 branch from 583c8c5 to 2938614 Compare July 16, 2025 20:36
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 80fc9d0 to 5d5f087 Compare July 16, 2025 20:36
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/9 to main July 16, 2025 20:40
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/10 branch from 5d5f087 to b7c508c Compare July 16, 2025 20:40
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/9 July 16, 2025 20:40
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. mx topic: not user facing Use this tag if you don't want this PR to show up in release notes
3 participants