
add custom op wrapping mxfp8 dim1 cast cuda kernel #2550


Merged

merged 1 commit into main on Jul 16, 2025

Conversation


pytorch-bot (bot) commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2550

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4b0038f with merge base c011bad:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 15, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/6 branch from 4195316 to 8d6327b on July 15, 2025 at 20:31
@danielvegamyhre added the mx and topic: not user facing labels on Jul 15, 2025

else:

def mxfp8_quantize_cuda(
Contributor

Nit: I hate this pattern. It would be cool to create a decorator like

@requires_triton

or something that will raise this error if the kernel is not available.

Contributor Author (@danielvegamyhre), Jul 15, 2025

Agree this is not ideal. I like the decorator idea, will try it.

Contributor Author (@danielvegamyhre), Jul 16, 2025

Hmm, it seems that with a decorator approach there are problems with the torch custom op internals trying to infer the schema:

../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:69: in error_fn
    raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")
E   ValueError: infer_schema(func): Parameter func must have a type annotation. Got func with signature (func))

I solved this by reversing the order of the custom op decorator and new decorator so the signature is preserved:

@requires_sm100
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(
...

However, I then get an error when trying to register a fake for the custom op:

torchao/prototype/mx_formats/kernels.py:1770: in <module>
    @mxfp8_quantize_cuda.register_fake
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   AttributeError: 'function' object has no attribute 'register_fake'

I inspected the object returned by the decorator and managed to figure out how to get it to work, but it is ugly and I think it will probably have side effects:

# ... in the decorator ... 
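    # Fragile: reaches into the closure of custom_op's inner decorator to
    # recover the underlying CustomOpDef object, so that .register_fake
    # exists on the returned value. Depends on torch internals and on the
    # exact closure layout.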
   return decorator.__closure__[0].cell_contents

Here's the full working decorator: https://www.internalfb.com/phabricator/paste/view/P1870964511

IMO I am not confident enough in this implementation to include it atm (I think it probably won't work correctly on non-sm100) and prefer the predictability of the explicit if/else statement (for now).

Contributor

cc @zou3519 Do you think that multiple decorators should be supported / supportable?

Contributor

Would be good to keep the code inside this file consistent; I'd vote for a separate PR fixing all the kernels with a better fallback syntax.

Contributor

What's the goal? To error out when the operator gets executed and we are < sm100?

What's wrong with:

@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(...):
    if not_is_sm100(...):
        raise RuntimeError(...)

The multiple decorators can work though. My rec is:

@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
@requires_sm100
def mxfp8_quantize_cuda(

but because requires_sm100 wraps the function, it should also propagate the wrapped function's type annotations onto the wrapper. That'll let torch.library.custom_op read the type annotations.
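For reference, a minimal sketch of an annotation-preserving decorator along those lines; the sm100 check via torch.cuda.get_device_capability() and the error message are illustrative assumptions, not the PR's actual implementation:

import functools

import torch

def requires_sm100(fn):
    # functools.wraps copies fn's __annotations__ (and sets __wrapped__) onto
    # the wrapper, so torch.library.custom_op can still infer the op schema
    # from the original type hints when applied on top of this decorator.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Assumed check: sm100 corresponds to CUDA compute capability 10.0.
        major, minor = torch.cuda.get_device_capability()
        if (major, minor) < (10, 0):
            raise RuntimeError(
                f"{fn.__name__} requires compute capability >= 10.0 (sm100), "
                f"found sm{major}{minor}"
            )
        return fn(*args, **kwargs)

    return wrapper

Applied below torch.library.custom_op as in the snippet above, the capability check runs each time the op is called, unlike the module-level if/else, which decides at import time.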

x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
Contributor

Would be good to use the scaling mode enum in as many places as possible, and convert to a string right before the call into the C++ wrapper.

Contributor Author

Makes sense, done

Contributor Author

Nevermind, I remembered why I had to use a string in the first place; it seems custom ops don't play nicely with enums:

torchao/prototype/mx_formats/kernels.py:1746: in <module>
    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/custom_ops.py:148: in inner
    schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:127: in infer_schema
    if annotation_type.__origin__ is tuple:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^
E   AttributeError: type object 'ScaleCalculationMode' has no attribute '__origin__'

Contributor

yeah, you can't pass an enum to a custom op.
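A minimal sketch of the workaround suggested above: keep the ScaleCalculationMode enum in the Python-facing wrapper and convert it to a plain string only at the custom-op boundary. The wrapper name mxfp8_quantize and the CEIL member are illustrative assumptions; only the FLOOR/"floor" default comes from the signature above:

from enum import Enum

import torch

class ScaleCalculationMode(Enum):
    # FLOOR mirrors the "floor" default in the op signature; CEIL is an
    # illustrative placeholder for the enum's other members.
    FLOOR = "floor"
    CEIL = "ceil"

def mxfp8_quantize(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
    # Custom ops cannot take enum arguments, so pass the enum's string
    # value right before crossing the op boundary.
    return torch.ops.torchao.mxfp8_quantize_cuda(
        x, rowwise=rowwise, colwise=colwise, scaling_mode=scaling_mode.value
    )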

@danielvegamyhre force-pushed the danielvegamyhre/stack/6 branch from 8d6327b to 1ab7788 on July 16, 2025 at 15:26
stack-info: PR: #2550, branch: danielvegamyhre/stack/6
@danielvegamyhre force-pushed the danielvegamyhre/stack/6 branch from 1ab7788 to 4b0038f on July 16, 2025 at 16:18
@danielvegamyhre merged commit 95d13d5 into main on Jul 16, 2025
18 of 19 checks passed