integrate mxfp8 dim1 cast kernel choice enum into MXLinear #2554
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2554
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 5 Pending as of commit b7c508c with merge base 95d13d5. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed d858130 to b01a184
stack-info: PR: #2554, branch: danielvegamyhre/stack/10
Force-pushed 97e9a7b to 380e887
cc @vkuzo for review of this stack. Changes are tested via the tests run in the "Test plan" section, which have been updated to cover all 3 dim1 cast kernel choices (none, triton, cuda).
Force-pushed 380e887 to 6311b94
```diff
@@ -33,6 +33,11 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"


+class MXFP8Dim1CastKernelChoice(Enum):
```
nit: name it to also support a dim0_dim1 cast if that is added in the future, e.g. MXFP8CastKernelChoice?
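For context, here is a minimal sketch of what the new enum might look like. The member names are assumptions inferred from the discussion of the three choices (none, triton, cuda); the "none" case appears to be represented by an `Optional[...]` config field set to `None` rather than by an enum member, and the actual definition in torchao may differ:

```python
from enum import Enum


class MXFP8Dim1CastKernelChoice(Enum):
    # Hypothetical members; actual values in torchao may differ.
    TRITON = "triton"  # hand-written Triton dim1 cast kernel
    CUDA = "cuda"      # hand-written CUDA dim1 cast kernel
```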
```diff
-# TODO(1945): remove this config option once torch.compile gives us
-# a fast kernel
-use_fp8_dim1_cast_triton_kernel: bool = False
+mxfp8_dim1_cast_kernel_choice: Optional[MXFP8Dim1CastKernelChoice] = (
```
How would someone use the torch.compile-generated kernel?
By setting mxfp8_dim1_cast_kernel_choice=None. Perhaps we should make it more explicit, by adding a 3rd enum option such as MXFP8CastKernelChoice.TORCH?
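The None-vs-explicit-member distinction discussed above can be sketched with a small dispatch helper. Everything below is a hypothetical stand-in rather than torchao's actual API; the stub bodies just tag which path was taken:

```python
from enum import Enum
from typing import Optional


class MXFP8Dim1CastKernelChoice(Enum):
    TRITON = "triton"
    CUDA = "cuda"


# Hypothetical stand-ins for the real cast implementations.
def _cast_dim1_torch(x):
    return ("torch.compile path", x)


def _cast_dim1_triton(x):
    return ("triton kernel", x)


def _cast_dim1_cuda(x):
    return ("cuda kernel", x)


def cast_dim1(x, choice: Optional[MXFP8Dim1CastKernelChoice]):
    # None selects the default torch.compile-generated kernel; an explicit
    # TORCH enum member would make this case self-documenting.
    if choice is None:
        return _cast_dim1_torch(x)
    if choice is MXFP8Dim1CastKernelChoice.TRITON:
        return _cast_dim1_triton(x)
    if choice is MXFP8Dim1CastKernelChoice.CUDA:
        return _cast_dim1_cuda(x)
    raise ValueError(f"unsupported kernel choice: {choice}")
```

Dispatching on an enum (rather than a boolean) keeps the `raise` branch as a safety net if a new kernel choice is added but not wired up.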
Force-pushed 6311b94 to 8a86dae
Force-pushed 8a86dae to ab8e821
Force-pushed ab8e821 to 7a48119
Force-pushed 7a48119 to 622b26d
Force-pushed 622b26d to 3efc354
Force-pushed 3efc354 to 80fc9d0
Force-pushed 583c8c5 to 2938614
Force-pushed 80fc9d0 to 5d5f087
Force-pushed 5d5f087 to b7c508c
Stacked PRs:
- integrate mxfp8 dim1 cast kernel choice enum into MXLinear

Summary
Add the MXFP8Dim1CastKernelChoice enum and replace all uses of the boolean flag use_fp8_dim1_cast_triton_kernel with it. (Default to Triton for now.)

Test plan
- pytest test/prototype/mx_formats/test_mx_linear.py -k eager_vs_hp
- pytest test/prototype/mx_formats/test_mx_linear.py -k compile

Next steps
Fix ./test/prototype/mx_formats/test_mx_dtensor.sh, which currently fails with:
RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
assert res.ndim == 0, "output tensor should be scalar!"
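To illustrate the summary above, here is a self-contained sketch of the config change. The dataclass is a simplified stand-in for torchao's actual MXLinearConfig; the field name and the Triton default are taken from the PR, everything else is an assumption:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class MXFP8Dim1CastKernelChoice(Enum):
    TRITON = "triton"
    CUDA = "cuda"


@dataclass
class MXLinearConfig:
    # Replaces the old boolean flag `use_fp8_dim1_cast_triton_kernel`.
    # None means "use the torch.compile-generated kernel".
    # Defaults to Triton for now, per the PR summary.
    mxfp8_dim1_cast_kernel_choice: Optional[MXFP8Dim1CastKernelChoice] = (
        MXFP8Dim1CastKernelChoice.TRITON
    )


# Example: opt in to the CUDA kernel instead of the Triton default.
cfg = MXLinearConfig(
    mxfp8_dim1_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA
)
```

Compared with the old boolean, the enum makes the kernel choice extensible (a third backend is a new member, not a second flag) and makes call sites self-describing.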