Skip to content

Commit 8d6327b

Browse files
add custom op wrapping mxfp8 dim1 cast cuda kernel
stack-info: PR: #2550, branch: danielvegamyhre/stack/6
1 parent c011bad commit 8d6327b

File tree

1 file changed

+84
-1
lines changed

1 file changed

+84
-1
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
_f32_to_floatx_unpacked,
1717
_floatx_unpacked_to_f32,
1818
)
19-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7
19+
from torchao.utils import (
20+
TORCH_VERSION_AT_LEAST_2_4,
21+
TORCH_VERSION_AT_LEAST_2_7,
22+
is_sm_at_least_100,
23+
)
2024

2125
# TODO(future): if needed, make the below work on previous PyTorch versions,
2226
# just need to hunt down the previous location of `libdevice`. An assert
@@ -1730,3 +1734,82 @@ def triton_quantize_nvfp4(
17301734
x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
17311735
) -> Tuple[torch.Tensor, torch.Tensor]:
17321736
raise AssertionError("needs torch version 2.8+ and triton")
1737+
1738+
1739+
# MXFP8 CUDA kernel is only built on SM100+
1740+
if is_sm_at_least_100():
1741+
from torchao.prototype import mxfp8_cuda
1742+
1743+
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
1744+
def mxfp8_quantize_cuda(
1745+
x: torch.Tensor,
1746+
rowwise: bool = False,
1747+
colwise: bool = True,
1748+
scaling_mode: str = "floor",
1749+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1750+
assert x.ndim == 2
1751+
rows, cols = x.shape
1752+
block_size = 32
1753+
assert rows % block_size == 0, "rows must be a multiple of 32"
1754+
assert cols % block_size == 0, "cols must be a multiple of 32"
1755+
output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
1756+
mxfp8_cuda.quantize(
1757+
x, rowwise=rowwise, colwise=colwise, scaling_mode=scaling_mode
1758+
)
1759+
)
1760+
return output_rowwise, output_colwise, scales_rowwise, scales_colwise
1761+
1762+
@mxfp8_quantize_cuda.register_fake
1763+
def _(
1764+
x: torch.Tensor,
1765+
rowwise: bool = False,
1766+
colwise: bool = True,
1767+
scaling_mode: str = "floor",
1768+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1769+
assert x.ndim == 2
1770+
rows, cols = x.shape
1771+
block_size = 32
1772+
assert rows % block_size == 0, "rows must be a multiple of 32"
1773+
assert cols % block_size == 0, "cols must be a multiple of 32"
1774+
num_row_blocks = rows // 32
1775+
num_col_blocks = cols // 32
1776+
1777+
# rowwise
1778+
if rowwise:
1779+
output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
1780+
scales_rowwise = x.new_empty(
1781+
rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
1782+
)
1783+
else:
1784+
output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
1785+
scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
1786+
1787+
# colwise
1788+
if colwise:
1789+
# column major
1790+
output_colwise = torch.empty_strided(
1791+
(rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
1792+
)
1793+
1794+
# colwise scales are written in column-major format to avoid uncoalesced global memory accesses
1795+
scales_colwise = torch.empty_strided(
1796+
(cols, num_row_blocks),
1797+
(1, cols),
1798+
dtype=torch.float8_e8m0fnu,
1799+
device=x.device,
1800+
)
1801+
else:
1802+
output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
1803+
scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
1804+
1805+
return output_rowwise, output_colwise, scales_rowwise, scales_colwise
1806+
1807+
else:
1808+
1809+
def mxfp8_quantize_cuda(
1810+
x: torch.Tensor,
1811+
rowwise: bool = False,
1812+
colwise: bool = True,
1813+
scaling_mode: str = "floor",
1814+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1815+
raise NotImplementedError("needs torch version 2.8+ and sm100")

0 commit comments

Comments
 (0)