
Commit 95d13d5

add custom op wrapping mxfp8 dim1 cast cuda kernel (#2550)
1 parent e93b7b6 commit 95d13d5

File tree

1 file changed: +92 -1 lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 92 additions & 1 deletion
@@ -16,7 +16,11 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_7,
+    is_sm_at_least_100,
+)
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -1730,3 +1734,90 @@ def triton_quantize_nvfp4(
     x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     raise AssertionError("needs torch version 2.8+ and triton")
+
+
+# MXFP8 CUDA kernel is only built on SM100+
+if is_sm_at_least_100():
+    from torchao.prototype import mxfp8_cuda
+
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
+    def mxfp8_quantize_cuda(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Input shape must be 2D.
+        assert x.ndim == 2
+        rows, cols = x.shape
+
+        # Rows and cols must be multiples of the block size (32).
+        block_size = 32
+        assert rows % block_size == 0, "rows must be a multiple of 32"
+        assert cols % block_size == 0, "cols must be a multiple of 32"
+
+        # Convert scaling mode to expected string format and call into kernel.
+        output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
+            mxfp8_cuda.quantize(
+                x,
+                rowwise=rowwise,
+                colwise=colwise,
+                scaling_mode=scaling_mode,
+            )
+        )
+        return output_rowwise, output_colwise, scales_rowwise, scales_colwise
+
+    @mxfp8_quantize_cuda.register_fake
+    def _(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert x.ndim == 2
+        rows, cols = x.shape
+        block_size = 32
+        assert rows % block_size == 0, "rows must be a multiple of 32"
+        assert cols % block_size == 0, "cols must be a multiple of 32"
+        num_row_blocks = rows // 32
+        num_col_blocks = cols // 32
+
+        # rowwise
+        if rowwise:
+            output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
+            scales_rowwise = x.new_empty(
+                rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
+            )
+        else:
+            output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
+            scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
+
+        # colwise
+        if colwise:
+            # column major
+            output_colwise = torch.empty_strided(
+                (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
+            )
+
+            # colwise scales are written in column-major format to avoid
+            # uncoalesced global memory accesses
+            scales_colwise = torch.empty_strided(
+                (cols, num_row_blocks),
+                (1, cols),
+                dtype=torch.float8_e8m0fnu,
+                device=x.device,
+            )
+        else:
+            output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
+            scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
+
+        return output_rowwise, output_colwise, scales_rowwise, scales_colwise
+
+else:
+
+    def mxfp8_quantize_cuda(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        raise NotImplementedError("needs torch version 2.8+ and sm100")
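Once registered, the custom op can be called like any other torch op. A minimal usage sketch, assuming an SM100+ GPU with the mxfp8_cuda extension built (the tensor shape and dtype here are illustrative, not from the commit):

    import torch
    from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

    # Defaults perform only the colwise (dim1) cast; the rowwise outputs
    # come back as empty tensors.
    _, out_colwise, _, scales_colwise = mxfp8_quantize_cuda(x)

    # Because a fake implementation is registered via register_fake, the op
    # can also be traced by torch.compile.
    @torch.compile
    def cast_both(t):
        return mxfp8_quantize_cuda(t, rowwise=True, colwise=True)

    out_r, out_c, scales_r, scales_c = cast_both(x)

torch.library.opcheck can additionally be used to check that the real and fake implementations agree on output shapes and strides.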
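The fake implementation mirrors the kernel's column-major output layout with torch.empty_strided. A short standalone sketch (assumed values, runs on CPU) shows why strides of (1, rows) mean column-major:

    import torch

    rows, cols = 4, 6
    t = torch.empty_strided((rows, cols), (1, rows))
    print(t.stride())             # (1, 4): moving down a column steps 1 element
    print(t.t().is_contiguous())  # True: the transpose is row-major contiguous

Laying out the colwise outputs and scales this way keeps writes within a column adjacent in memory, which is the coalescing concern the in-code comment refers to.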
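The register_fake path also means shape propagation works without launching the kernel, e.g. under FakeTensorMode. A sketch, again assuming the SM100+ branch was taken at import time; the expected shapes follow the block math in the fake impl (64 rows / 32 = 2 row blocks):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

    with FakeTensorMode():
        x = torch.empty(64, 96, device="cuda")
        out_r, out_c, s_r, s_c = mxfp8_quantize_cuda(x, rowwise=True, colwise=True)
        # colwise scales are (cols, num_row_blocks) = (96, 2)
        print(out_c.shape, s_c.shape)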
