integration of new mxfp8 casting cuda kernel #2564
Conversation
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (651b912 → 3f88897)
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2564
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit cc00ef6 with merge base 95d13d5. This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (3f88897 → 20847a7)
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (20847a7 → b1ed196)
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (b1ed196 → bb930b6)
Didn't read any code yet: I think updating to just accept generically strided inputs and writing out row major (required for mm kernels) is good, is that what you did?
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (bb930b6 → b1e237f)
Yep, that's correct.
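For reference, here is a minimal sketch of the contract discussed above, under stated assumptions: the helper name is made up, and the blockwise scale computation a real mxfp8 cast performs is omitted. The point is only the layout contract: accept an arbitrarily strided input and always write a row major (contiguous) output, since that is what the mm kernels require.

```python
import torch

def cast_dim1_rowmajor_reference(a: torch.Tensor) -> torch.Tensor:
    # Input may be row major, col major, or otherwise strided.
    assert a.ndim == 2
    # Cast to fp8, then force a row major (contiguous) layout for the mm kernels.
    # A real mxfp8 cast also computes blockwise scales; that is omitted here.
    return a.to(torch.float8_e4m3fn).contiguous()

x = torch.randn(64, 128, dtype=torch.bfloat16).t()  # a col major (128, 64) view
y = cast_dim1_rowmajor_reference(x)
assert y.shape == (128, 64) and y.is_contiguous()
```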
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (b1e237f → c9c9da0)
@@ -25,6 +25,7 @@
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
sorry, MXFP8Dim1CastKernelChoice? since for dim0 we are always using torch.compile
It was originally MXFP8Dim1CastKernelChoice, but here we discussed naming it MXFP8CastKernelChoice to potentially add support for dim0 and dim0+dim1 casts as well. I don't have strong opinions either way, so I went ahead and changed it back to MXFP8Dim1CastKernelChoice.
@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
    CUBLAS = "cublas"

class MXFP8CastKernelChoice(Enum):
add some comments on what the options are?
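One possible way to address this is sketched below; this is not the exact torchao source, and the option names simply mirror the triton | cuda | torch values used elsewhere in this stack.

```python
# Sketch of a documented version of the enum discussed above (illustrative only).
from enum import Enum

class MXFP8Dim1CastKernelChoice(Enum):
    """Which kernel performs the mxfp8 dim1 (column-wise) cast."""

    TRITON = "triton"  # handwritten Triton kernel (current default)
    CUDA = "cuda"      # new CUDA kernel integrated in this PR
    TORCH = "torch"    # plain PyTorch ops, compiled via torch.compile
```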
a,
rowwise=False,
colwise=True,
scaling_mode="floor",
TODO for later to allow choice of scaling modes
Added a TODO on the custom op itself, with an explanation of why we are currently using a string param.
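As a rough illustration of what such a TODO could look like on a custom op, here is a hedged sketch: the op namespace, name, and body are made up, and only the `rowwise` / `colwise` / `scaling_mode` parameters are taken from the diff context above.

```python
import torch
from torch.library import custom_op

@custom_op("mx_sketch::mxfp8_quantize", mutates_args=())
def mxfp8_quantize(
    a: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    # TODO(future): allow a real choice of scaling modes. A plain string is used
    # for now because custom op schemas support str but not Python enums.
    scaling_mode: str = "floor",
) -> torch.Tensor:
    # Placeholder eager implementation; the real kernel also computes blockwise
    # scales and can write col major outputs for the dim1 cast.
    return a.to(torch.float8_e4m3fn)
```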
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (c9c9da0 → 7b1f899)
stack-info: PR: #2564, branch: danielvegamyhre/stack/13 (7b1f899 → cc00ef6)
Stacked PRs: * __->__ #1427

make mxfp8 dim1 cast kernel configurable

## Summary
- We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564
- This PR updates the mxfp8 user-facing API to replace the boolean flag `--mx.use_triton_for_dim1_cast=[true|false]` with `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]`

## Test plan
- Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"`
- Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"`
- Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"`

## Limitations
- TP is not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were discussing with Brian, and we will continue following up on it.
Stacked PRs:

integration of new mxfp8 casting cuda kernel

## Summary
Integrating the kernel added in #2513. The custom op wrapper was recently added in #2543. The remaining code to migrate from my private repo is in this PR:
- `mxfp8_quantize_cuda` custom op.
- Update the `triton_scale_swizzle` kernel to accept both row major and col major inputs (since the cuda kernel writes scales in col major, to avoid uncoalesced global accesses).
- Add the `MXFP8Dim1CastKernelChoice` enum and replace all uses of the boolean flag `use_fp8_dim1_cast_triton_kernel` with it (default to Triton for now); a usage sketch follows below.
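For illustration, a hypothetical usage sketch of the new enum; the `MXLinearConfig` field name below is an assumption, not something confirmed by this PR's diff context, so treat this as the intended shape of the API rather than the exact interface.

```python
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice

# Select the new CUDA kernel for the dim1 cast.
# NOTE: the field name `mxfp8_dim1_cast_kernel_choice` is assumed here.
config = MXLinearConfig(
    mxfp8_dim1_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
)
```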
## Test plan
- `pytest test/prototype/mx_formats/test_mx_linear.py -k eager_vs_hp`
- `pytest test/prototype/mx_formats/test_mx_linear.py -k compile`
## Next steps
- Running `./test/prototype/mx_formats/test_mx_dtensor.sh` currently fails with: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue, will follow up on it.