integration of new mxfp8 casting cuda kernel #2564


Merged
danielvegamyhre merged 1 commit into main from danielvegamyhre/stack/13 on Jul 18, 2025

Conversation

@danielvegamyhre (Contributor) commented Jul 16, 2025

Stacked PRs:


integration of new mxfp8 casting cuda kernel

Summary

Integrating the kernel added in #2513. The custom op wrapper was recently added in #2543. The remaining code to migrate from my private repo is in this PR:

  • Register a sharding strategy for the mxfp8_quantize_cuda custom op.
  • Add a wrapper with DTensor handling for the mxfp8_quantize_cuda custom op.
  • Update the triton_scale_swizzle kernel to accept both row-major and col-major inputs (the CUDA kernel writes scales in col-major to avoid uncoalesced global accesses).
  • Add an MXFP8Dim1CastKernelChoice enum and replace all uses of the boolean flag use_fp8_dim1_cast_triton_kernel with it (default to Triton for now); see the sketch after this list.
  • Update tests accordingly and verify they are passing.
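
For orientation, here is a minimal sketch of the enum and dispatch described above. The enum values and the mxfp8_quantize_cuda op name come from this PR; everything else, including the op's namespace, signature, return values, and the helper/function names, is an assumption for illustration, and the real wrapper may rely on the registered sharding strategy rather than manual DTensor unwrapping.

```python
from enum import Enum

import torch
from torch.distributed.tensor import DTensor


class MXFP8Dim1CastKernelChoice(Enum):
    TRITON = "triton"  # handwritten Triton dim1 cast kernel (current default)
    CUDA = "cuda"      # new CUDA kernel from pytorch/ao#2513
    TORCH = "torch"    # plain PyTorch ops, lowered via torch.compile


def _cast_dim1_triton(hp_tensor):
    raise NotImplementedError("placeholder for the Triton dim1 cast")


def _cast_dim1_torch(hp_tensor):
    raise NotImplementedError("placeholder for the torch.compile dim1 cast")


def _cast_dim1_cuda(hp_tensor):
    # One possible DTensor-handling wrapper: quantize the local shard with the
    # custom op, then rewrap the outputs with the input's mesh and placements
    # (assumes the scale shards follow the data shards).
    if isinstance(hp_tensor, DTensor):
        data, scales = torch.ops.torchao.mxfp8_quantize_cuda(
            hp_tensor.to_local(), rowwise=False, colwise=True
        )
        return (
            DTensor.from_local(data, hp_tensor.device_mesh, hp_tensor.placements),
            DTensor.from_local(scales, hp_tensor.device_mesh, hp_tensor.placements),
        )
    return torch.ops.torchao.mxfp8_quantize_cuda(
        hp_tensor, rowwise=False, colwise=True
    )


def cast_dim1_to_mxfp8(hp_tensor, choice: MXFP8Dim1CastKernelChoice):
    dispatch = {
        MXFP8Dim1CastKernelChoice.TRITON: _cast_dim1_triton,
        MXFP8Dim1CastKernelChoice.CUDA: _cast_dim1_cuda,
        MXFP8Dim1CastKernelChoice.TORCH: _cast_dim1_torch,
    }
    return dispatch[choice](hp_tensor)
```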

Test plan

  • pytest test/prototype/mx_formats/test_mx_linear.py -k eager_vs_hp
  • pytest test/prototype/mx_formats/test_mx_linear.py -k compile

Next steps

  • Integrate into torchtitan for e2e FSDP training tests once this stack lands. Torchtitan PR: [mxpf8] Make mxfp8 dim1 cast kernel configurable torchtitan#1401
  • DTensor tests (./test/prototype/mx_formats/test_mx_dtensor.sh) are still failing with both the Triton and CUDA kernels: RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue; will follow up on it.

danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from 651b912 to 3f88897 on July 16, 2025, 21:14

pytorch-bot bot commented Jul 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2564

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cc00ef6 with merge base 95d13d5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Jul 16, 2025
danielvegamyhre added the topic: not user facing label Jul 16, 2025
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from 3f88897 to 20847a7 on July 16, 2025, 21:31
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from 20847a7 to b1ed196 on July 16, 2025, 21:39
danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from b1ed196 to bb930b6 on July 16, 2025, 21:56
danielvegamyhre added the topic: improvement and topic: performance labels and removed the topic: not user facing label Jul 16, 2025
danielvegamyhre requested a review from vkuzo July 16, 2025 22:49
@danielvegamyhre (Contributor, Author) commented:

@vkuzo @drisspg I squashed the remaining changes in the original stack into the same PR, so that the tests would be in the same PR as the changes.

@drisspg (Contributor) commented Jul 16, 2025

Didn't read any code yet:

> Update triton_scale_swizzle kernel to accept both row major and col major inputs (since cuda kernel writes scale in col major, to avoid uncoalesced global accesses)

I think updating it to just accept generically strided inputs and write out row-major (required by the mm kernels) is good. Is that what you did?

danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from bb930b6 to b1e237f on July 16, 2025, 23:07
@danielvegamyhre (Contributor, Author) replied:

> Update triton_scale_swizzle kernel to accept both row major and col major inputs (since cuda kernel writes scale in col major, to avoid uncoalesced global accesses)
>
> I think updating it to just accept generically strided inputs and write out row-major (required by the mm kernels) is good. Is that what you did?

Yep, that's correct.
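
As an illustration of that design, here is a minimal, self-contained Triton sketch (assuming Triton is installed): the kernel reads the input through its 2D strides, so row-major and col-major tensors are handled by the same code path, and the output is always written row-major. It is only a strided copy with made-up names (`strided_to_row_major_kernel`, `to_row_major`), not the actual `triton_scale_swizzle` kernel, which additionally rearranges the scales into the blocked layout the mm kernels expect.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def strided_to_row_major_kernel(
    in_ptr, out_ptr,
    stride_in_r, stride_in_c,
    stride_out_r, stride_out_c,
    R, C,
    BLOCK_R: tl.constexpr, BLOCK_C: tl.constexpr,
):
    pid_r = tl.program_id(0)
    pid_c = tl.program_id(1)
    offs_r = pid_r * BLOCK_R + tl.arange(0, BLOCK_R)
    offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
    mask = (offs_r[:, None] < R) & (offs_c[None, :] < C)
    # Generic input strides: row-major and col-major inputs both resolve to
    # the correct elements through the same addressing code.
    src = in_ptr + offs_r[:, None] * stride_in_r + offs_c[None, :] * stride_in_c
    dst = out_ptr + offs_r[:, None] * stride_out_r + offs_c[None, :] * stride_out_c
    tl.store(dst, tl.load(src, mask=mask), mask=mask)


def to_row_major(x: torch.Tensor) -> torch.Tensor:
    R, C = x.shape
    out = torch.empty(R, C, dtype=x.dtype, device=x.device)  # row-major output
    grid = (triton.cdiv(R, 32), triton.cdiv(C, 32))
    strided_to_row_major_kernel[grid](
        x, out,
        x.stride(0), x.stride(1),
        out.stride(0), out.stride(1),
        R, C, BLOCK_R=32, BLOCK_C=32,
    )
    return out
```

With this shape of API, `to_row_major(s)` and `to_row_major(s.t().contiguous().t())` return the same row-major values even though the inputs have different memory layouts.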

danielvegamyhre added a commit that referenced this pull request Jul 16, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from b1e237f to c9c9da0 on July 16, 2025, 23:12
@@ -25,6 +25,7 @@
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
Contributor:

sorry, MXFP8Dim1CastKernelChoice? since for dim0 we are always using torch.compile

@danielvegamyhre (Contributor, Author) replied Jul 17, 2025:

It was originally MXFP8Dim1CastKernelChoice, but here we discussed naming it MXFP8CastKernelChoice to leave room for dim0 and dim0+dim1 casts as well. I don't have strong opinions either way, so I went ahead and changed it back to MXFP8Dim1CastKernelChoice.

@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
CUBLAS = "cublas"


class MXFP8CastKernelChoice(Enum):
Contributor:

add some comments on what the options are?

a,
rowwise=False,
colwise=True,
scaling_mode="floor",
Contributor:

TODO for later to allow choice of scaling modes

@danielvegamyhre (Contributor, Author) replied:

Added a TODO on the custom op itself, with an explanation of why we are currently using a string param.
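
For context, roughly the shape such a TODO might take (hypothetical signature; the real op definition lives in torchao and may differ). Custom op schemas accept plain types like str rather than Python enums, which is why scaling_mode is a string at this boundary.

```python
# Hypothetical sketch only, not the actual torchao op definition.
def mxfp8_quantize_cuda(a, rowwise=False, colwise=True, scaling_mode="floor"):
    # TODO: support selecting other scaling modes here. The custom op schema
    # only accepts plain types (e.g. str), not Python enums, so the mode is
    # passed as a string and validated/mapped inside the op.
    ...
```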

danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from c9c9da0 to 7b1f899 on July 17, 2025, 23:33
danielvegamyhre added a commit that referenced this pull request Jul 17, 2025
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
danielvegamyhre force-pushed the danielvegamyhre/stack/13 branch from 7b1f899 to cc00ef6 on July 17, 2025, 23:40
danielvegamyhre merged commit d828f91 into main Jul 18, 2025
19 checks passed
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Jul 25, 2025

Stacked PRs:
 * __->__ #1427

--- --- ---

make mxfp8 dim1 cast kernel configurable

## Summary
- We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564
- This PR updates the mxfp8 user-facing API, replacing the boolean flag `--mx.use_triton_for_dim1_cast=[true|false]` with `--mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]`

## Test plan
- Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"`
- CUDA: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"`
- Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"`

## Limitations
- TP is not supported yet, as both the Triton and CUDA kernels hit the same issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we have been discussing with Brian; will continue following up on it.
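
As a rough sketch of the string-to-enum plumbing this flag implies (the helper name and mapping are illustrative, not the actual torchtitan converter code; the enum import path follows pytorch/ao#2564, assuming members named TRITON, CUDA, and TORCH):

```python
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice

_KERNEL_CHOICES = {
    "triton": MXFP8Dim1CastKernelChoice.TRITON,
    "cuda": MXFP8Dim1CastKernelChoice.CUDA,
    "torch": MXFP8Dim1CastKernelChoice.TORCH,
}


def parse_dim1_cast_kernel_choice(name: str) -> MXFP8Dim1CastKernelChoice:
    """Map the --mx.mxfp8_dim1_cast_kernel_choice string onto the torchao enum."""
    try:
        return _KERNEL_CHOICES[name.lower()]
    except KeyError:
        raise ValueError(
            f"mxfp8_dim1_cast_kernel_choice must be one of "
            f"{sorted(_KERNEL_CHOICES)}, got {name!r}"
        )
```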
idoh pushed a commit to idoh/torchtitan that referenced this pull request Jul 28, 2025
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
joellidin pushed a commit to tplr-ai/torchtitan that referenced this pull request Aug 8, 2025
joellidin pushed a commit to tplr-ai/torchtitan that referenced this pull request Aug 8, 2025