add custom op wrapping mxfp8 dim1 cast cuda kernel #2550
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2550
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 4b0038f with merge base c011bad.
BROKEN TRUNK - The following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 4195316 to 8d6327b
Review thread on these diff lines:
else:
def mxfp8_quantize_cuda(
Nit: I hate this pattern. It would be cool to create a decorator like
@requires_triton
or something that raises this error if triton is not available.
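For illustration, a minimal sketch of the kind of decorator being suggested (the requires_triton name and the error message are placeholders, not code from this PR):

def requires_triton(fn):
    # Placeholder sketch: perform the availability check once, at call time,
    # instead of repeating the if/else fallback in every kernel wrapper.
    def wrapper(*args, **kwargs):
        try:
            import triton  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "triton is required for this kernel but is not installed"
            ) from e
        return fn(*args, **kwargs)
    # Note: a bare wrapper like this hides fn's signature, which matters for
    # torch.library.custom_op schema inference (see the discussion below).
    return wrapper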
Agree, this is not ideal. I like the decorator idea, will try it.
Hmm, it seems like with a decorator approach there are problems with the torch custom op internals trying to infer the schema:
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:69: in error_fn
raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")
E ValueError: infer_schema(func): Parameter func must have a type annotation. Got func with signature (func))
I solved this by reversing the order of the custom op decorator and new decorator so the signature is preserved:
@requires_sm100
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(
...
However, I then get an error when trying to register a fake for the custom op:
torchao/prototype/mx_formats/kernels.py:1770: in <module>
@mxfp8_quantize_cuda.register_fake
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E AttributeError: 'function' object has no attribute 'register_fake'
I inspected the object returned by the decorator and managed to figure out how to get it to work, but it is ugly and I think it will probably have side effects:
# ... in the decorator ...
return decorator.__closure__[0].cell_contents
Here's the full working decorator: https://www.internalfb.com/phabricator/paste/view/P1870964511
IMO I am not confident enough in this implementation to include it atm (I think it probably won't work correctly on non-sm100) and prefer the predictability of the explicit check (for now).
cc @zou3519 Do you think that multiple decorators should be supported / supportable?
It would be good to keep the code inside this file consistent; I'd vote for a separate PR fixing all the kernels with a better fallback syntax.
What's the goal? To error out when the operator gets executed and we are < sm100?
What's wrong with:
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(...):
if not_is_sm100(...):
raise RuntimeError(...)
The multiple decorators can work though. My rec is:
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
@requires_sm100
def mxfp8_quantize_cuda(
but because requires_sm100 wraps a function, it should also propagate the wrapped function's type annotations to the wrapper. That'll let torch.library.custom_op read the type annotations.
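A minimal sketch of that recommendation (the capability check, error message, and example op below are illustrative, not code from this PR): functools.wraps propagates the wrapped function's signature and annotations, so torch.library.custom_op can still infer the schema through the guard decorator.

import functools

import torch

def requires_sm100(fn):
    # functools.wraps copies fn's metadata (including __annotations__ and
    # __wrapped__), so inspect.signature, and therefore schema inference,
    # still sees the original annotated signature.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        major, minor = torch.cuda.get_device_capability()
        if (major, minor) < (10, 0):
            raise RuntimeError(f"{fn.__name__} requires an sm100+ GPU")
        return fn(*args, **kwargs)
    return wrapper

# custom_op on the outside, the guard on the inside, per the recommendation above.
@torch.library.custom_op("torchao::sm100_guard_example", mutates_args=())
@requires_sm100
def sm100_guard_example(x: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op would dispatch into the CUDA kernel.
    return x.clone()

@sm100_guard_example.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)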
Review thread on these diff lines:
x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
It would be good to use the scaling mode enum in as many places as possible, and convert to a string right before the call into the C++ wrapper.
Makes sense, done
Nevermind, I remembered why I had to use a string in the first place: it seems custom ops don't play nicely with enums:
torchao/prototype/mx_formats/kernels.py:1746: in <module>
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/custom_ops.py:148: in inner
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:127: in infer_schema
if annotation_type.__origin__ is tuple:
^^^^^^^^^^^^^^^^^^^^^^^^^^
E AttributeError: type object 'ScaleCalculationMode' has no attribute '__origin__'
yeah, you can't pass an enum to a custom op.
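Given that limitation, one workaround is to keep the enum in the Python-facing wrapper and convert it to its string value right at the boundary, before calling the custom op. A rough sketch (the enum members and the mxfp8_quantize wrapper are illustrative, not the PR's actual code):

from enum import Enum

import torch

class ScaleCalculationMode(Enum):
    # Illustrative stand-in for the real enum used by the mx_formats code.
    FLOOR = "floor"
    CEIL = "ceil"

def mxfp8_quantize(
    x: torch.Tensor,
    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
    # The custom op schema only understands str, so convert at the boundary.
    return torch.ops.torchao.mxfp8_quantize_cuda(
        x,
        rowwise=False,
        colwise=True,
        scaling_mode=scaling_mode.value,
    )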
Force-pushed from 8d6327b to 1ab7788
stack-info: PR: #2550, branch: danielvegamyhre/stack/6
Force-pushed from 1ab7788 to 4b0038f
Stacked PRs:
add custom op wrapping mxfp8 dim1 cast cuda kernel