93 changes: 92 additions & 1 deletion torchao/prototype/mx_formats/kernels.py
@@ -16,7 +16,11 @@
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_7,
is_sm_at_least_100,
)

# TODO(future): if needed, make the below work on previous PyTorch versions,
# just need to hunt down the previous location of `libdevice`. An assert
@@ -1730,3 +1734,90 @@ def triton_quantize_nvfp4(
x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
raise AssertionError("needs torch version 2.8+ and triton")


# MXFP8 CUDA kernel is only built on SM100+
if is_sm_at_least_100():
from torchao.prototype import mxfp8_cuda

@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(
x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
Contributor: would be good to use the scaling mode enum in as many places as possible, and convert to string right before the call into the C++ wrapper.

Contributor Author: Makes sense, done.

Contributor Author: Never mind, I remembered why I had to use a string in the first place; it seems custom ops don't play nicely with enums:

torchao/prototype/mx_formats/kernels.py:1746: in <module>
    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/custom_ops.py:148: in inner
    schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:127: in infer_schema
    if annotation_type.__origin__ is tuple:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^
E   AttributeError: type object 'ScaleCalculationMode' has no attribute '__origin__'

Contributor: yeah, you can't pass an enum to a custom op. (A sketch of the boundary conversion suggested above appears after the diff.)

) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Input shape must be 2D.
assert x.ndim == 2
rows, cols = x.shape

# Rows and cols must both be multiples of the block size (32).
block_size = 32
assert rows % block_size == 0, "rows must be a multiple of 32"
assert cols % block_size == 0, "cols must be a multiple of 32"

# Call into the CUDA kernel with the given scaling mode.
output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
mxfp8_cuda.quantize(
x,
rowwise=rowwise,
colwise=colwise,
scaling_mode=scaling_mode,
)
)
return output_rowwise, output_colwise, scales_rowwise, scales_colwise

@mxfp8_quantize_cuda.register_fake
def _(
x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
assert x.ndim == 2
rows, cols = x.shape
block_size = 32
assert rows % block_size == 0, "rows must be a multiple of 32"
assert cols % block_size == 0, "cols must be a multiple of 32"
num_row_blocks = rows // block_size
num_col_blocks = cols // block_size

# rowwise
if rowwise:
output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
scales_rowwise = x.new_empty(
rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
)
else:
output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

# colwise
if colwise:
# column major
output_colwise = torch.empty_strided(
(rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
)

# colwise scales are written in column-major format to avoid uncoalesced global memory accesses
scales_colwise = torch.empty_strided(
(cols, num_row_blocks),
(1, cols),
dtype=torch.float8_e8m0fnu,
device=x.device,
)
else:
output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

return output_rowwise, output_colwise, scales_rowwise, scales_colwise

else:

def mxfp8_quantize_cuda(
Contributor: Nit: I hate this pattern. It would be cool to create a decorator like @requires_triton, or something that will raise this error if not available.

Contributor Author (danielvegamyhre, Jul 15, 2025): Agree this is not ideal. I like the decorator idea, will try it.

Contributor Author (danielvegamyhre, Jul 16, 2025): Hmm, it seems that with a decorator approach there are problems with torch custom op internals trying to infer the schema:

../.conda/envs/ao/lib/python3.13/site-packages/torch/_library/infer_schema.py:69: in error_fn
    raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")
E   ValueError: infer_schema(func): Parameter func must have a type annotation. Got func with signature (func))

I solved this by reversing the order of the custom op decorator and new decorator so the signature is preserved:

@requires_sm100
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(
...

However, I then get an error when trying to register a fake for the custom op:

torchao/prototype/mx_formats/kernels.py:1770: in <module>
    @mxfp8_quantize_cuda.register_fake
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   AttributeError: 'function' object has no attribute 'register_fake'

I inspected the object returned by the decorator and managed to figure out how to get it to work, but it is ugly and I think it will probably have side effects:

# ... in the decorator ... 
   return decorator.__closure__[0].cell_contents

Here's the full working decorator: https://www.internalfb.com/phabricator/paste/view/P1870964511

IMO I am not confident enough in this implementation to include it atm (I think it probably won't work correctly on non-sm100) and prefer the predictability of the if/else statement (for now).

Contributor: cc @zou3519 Do you think that multiple decorators should be supported / supportable?

Contributor: It would be good to keep the code inside this file consistent; I'd vote for a separate PR fixing all the kernels with a better fallback syntax.

Contributor: What's the goal? To error out when the operator gets executed and we are < sm100?

What's wrong with:

@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(...):
    if not_is_sm100(...):
        raise RuntimeError(...)

The multiple decorators can work though. My rec is:

@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
@requires_sm100
def mxfp8_quantize_cuda(

but because requires_sm100 wraps a function, it should also propagate the type annotations to the wrapped function. That'll let torch.library.custom_op read from the type annotations. (A sketch of such a decorator appears after the diff.)

x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
raise NotImplementedError("needs torch version 2.8+ and sm100")
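For context, a minimal usage sketch of the custom op registered above. This is an illustration, not code from the PR: the input dtype is an assumption, and the output dtypes and layouts follow the fake implementation in this diff. It requires an sm100+ GPU with the mxfp8_cuda extension built.

import torch

# Both dimensions must be multiples of 32, per the asserts in the op.
x = torch.randn(128, 256, device="cuda")  # input dtype assumed here

out_rowwise, out_colwise, scales_rowwise, scales_colwise = (
    torch.ops.torchao.mxfp8_quantize_cuda(
        x, rowwise=True, colwise=True, scaling_mode="floor"
    )
)
# Per the fake impl: outputs are float8_e4m3fn, scales are float8_e8m0fnu,
# and the colwise output/scales are laid out in column-major order.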
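The enum-to-string boundary conversion suggested in the first review thread could look like the following sketch. The enum's member names and values are assumptions (only the name ScaleCalculationMode appears in the traceback above), and the wrapper function is hypothetical.

import torch
from enum import Enum

class ScaleCalculationMode(Enum):
    # Member names/values assumed for illustration.
    FLOOR = "floor"
    CEIL = "ceil"

def mxfp8_quantize(
    x: torch.Tensor, mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
):
    # Keep the enum in user-facing code and convert to a string only at the
    # custom-op boundary, since torch.library cannot infer a schema for
    # enum-typed parameters.
    return torch.ops.torchao.mxfp8_quantize_cuda(x, scaling_mode=mode.value)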
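Finally, a sketch of the decorator ordering recommended in the last comment. functools.wraps copies the wrapped function's __annotations__ (and sets __wrapped__, which inspect.signature follows), so torch.library.custom_op applied on top can still infer the schema. The guard body is an assumption about how this would be wired up, not code from the PR.

import functools

def requires_sm100(fn):
    # functools.wraps propagates the type annotations that
    # torch.library.custom_op needs for schema inference.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not is_sm_at_least_100():
            raise NotImplementedError("mxfp8_quantize_cuda requires sm100+")
        return fn(*args, **kwargs)
    return wrapper

# Ordering per the recommendation above, custom_op outermost:
#
# @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
# @requires_sm100
# def mxfp8_quantize_cuda(...): ...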