add custom op wrapping mxfp8 dim1 cast cuda kernel #2550
Changes from all commits
@@ -16,7 +16,11 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_7,
+    is_sm_at_least_100,
+)

 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -1730,3 +1734,90 @@ def triton_quantize_nvfp4(
    x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    raise AssertionError("needs torch version 2.8+ and triton")


# MXFP8 CUDA kernel is only built on SM100+
if is_sm_at_least_100():
    from torchao.prototype import mxfp8_cuda

    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
    def mxfp8_quantize_cuda(
        x: torch.Tensor,
        rowwise: bool = False,
        colwise: bool = True,
        scaling_mode: str = "floor",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Input shape must be 2D.
        assert x.ndim == 2
        rows, cols = x.shape

        # Block size must be a multiple of 32.
        block_size = 32
        assert rows % block_size == 0, "rows must be a multiple of 32"
        assert cols % block_size == 0, "cols must be a multiple of 32"

        # Convert scaling mode to expected string format and call into kernel.
        output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
            mxfp8_cuda.quantize(
                x,
                rowwise=rowwise,
                colwise=colwise,
                scaling_mode=scaling_mode,
            )
        )
        return output_rowwise, output_colwise, scales_rowwise, scales_colwise

    @mxfp8_quantize_cuda.register_fake
    def _(
        x: torch.Tensor,
        rowwise: bool = False,
        colwise: bool = True,
        scaling_mode: str = "floor",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        assert x.ndim == 2
        rows, cols = x.shape
        block_size = 32
        assert rows % block_size == 0, "rows must be a multiple of 32"
        assert cols % block_size == 0, "cols must be a multiple of 32"
        num_row_blocks = rows // 32
        num_col_blocks = cols // 32

        # rowwise
        if rowwise:
            output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
            scales_rowwise = x.new_empty(
                rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
            )
        else:
            output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
            scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

        # colwise
        if colwise:
            # column major
            output_colwise = torch.empty_strided(
                (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
            )

            # colwise scales are written in column-major format to avoid
            # uncoalesced global memory accesses
            scales_colwise = torch.empty_strided(
                (cols, num_row_blocks),
                (1, cols),
                dtype=torch.float8_e8m0fnu,
                device=x.device,
            )
        else:
            output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
            scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

        return output_rowwise, output_colwise, scales_rowwise, scales_colwise

else:

    def mxfp8_quantize_cuda(
        x: torch.Tensor,
        rowwise: bool = False,
        colwise: bool = True,
        scaling_mode: str = "floor",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        raise NotImplementedError("needs torch version 2.8+ and sm100")

Review discussion on this fallback:

Reviewer: Nit: I hate this pattern. It would be cool to create a decorator or something that will raise this error if the kernel is not available.

Author: Agree this is not ideal. I like the decorator idea, will try it.

Author: Hmm, it seems like with a decorator approach there are problems with torch custom op internals trying to infer the schema. I solved this by reversing the order of the custom op decorator and the new decorator so the signature is preserved:

    @requires_sm100
    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
    def mxfp8_quantize_cuda(
        ...

However, I then get an error when trying to register a fake for the custom op. I inspected the object returned by the decorator and managed to figure out how to get it to work, but it is ugly and I think it will probably have side effects:

    # ... in the decorator ...
    return decorator.__closure__[0].cell_contents

Here's the full working decorator: https://www.internalfb.com/phabricator/paste/view/P1870964511. IMO I am not confident enough in this implementation to include it at the moment (I think it probably won't work correctly on non-sm100) and prefer the predictability of the if/else statement (for now).

Reviewer: cc @zou3519 Do you think that multiple decorators should be supported / supportable?

Reviewer: Would be good to keep the code inside this file consistent; I'd vote for a separate PR fixing all the kernels with a better fallback syntax.

Reviewer: What's the goal? To error out when the operator gets executed and we are < sm100? What's wrong with:

The multiple decorators can work though. My rec is:

but because requires_sm100 wraps a function, it should also propagate the type annotations to the wrapped function. That'll let torch.library.custom_op read from the type annotations.
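For reference, a minimal sketch of the decorator variant recommended above: requires_sm100 uses functools.wraps so the wrapper keeps the wrapped function's signature and type annotations, which is what lets torch.library.custom_op infer the schema when it is stacked on the outside. The decorator name, error message, and stacking order shown here are illustrative assumptions from the discussion, not the code that landed in this PR.

    import functools

    from torchao.utils import is_sm_at_least_100


    def requires_sm100(fn):
        """Raise at call time if the MXFP8 CUDA kernel is unavailable (< sm100)."""

        # functools.wraps copies __annotations__ and sets __wrapped__, so the
        # wrapper still exposes the original signature and type hints that
        # torch.library.custom_op needs in order to infer the op schema.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not is_sm_at_least_100():
                raise NotImplementedError("mxfp8_quantize_cuda requires sm100+")
            return fn(*args, **kwargs)

        return wrapper


    # Intended stacking order per the recommendation: custom_op on the outside,
    # so it reads the preserved annotations from the wrapper (assumed, untested):
    #
    # @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
    # @requires_sm100
    # def mxfp8_quantize_cuda(...):
    #     ...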
Review discussion on the scaling_mode argument:

Reviewer: Would be good to use the scaling mode enum in as many places as possible, and convert to a string right before the call into the C++ wrapper.

Author: Makes sense, done.

Author: Never mind, I remembered why I had to use a string in the first place; it seems custom ops don't play nicely with enums.

Reviewer: Yeah, you can't pass an enum to a custom op.
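A minimal sketch of the pattern the reviewer suggests, assuming an enum is kept on the Python side (the ScalingMode enum, its members, and the quantize_dim1 helper are illustrative names, not APIs from this PR): keep the enum in user-facing code and convert it to a plain string right before crossing the custom-op boundary, since custom-op schemas accept str but not Python enums. mxfp8_quantize_cuda below refers to the wrapper defined in the diff above.

    from enum import Enum

    import torch


    class ScalingMode(Enum):
        # "floor" is the mode used in this PR; other members omitted here.
        FLOOR = "floor"


    def quantize_dim1(x: torch.Tensor, mode: ScalingMode = ScalingMode.FLOOR):
        # Convert the enum to the plain string the custom op schema expects,
        # right at the call into the custom op wrapper.
        _, output_colwise, _, scales_colwise = mxfp8_quantize_cuda(
            x, rowwise=False, colwise=True, scaling_mode=mode.value
        )
        return output_colwise, scales_colwise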