
Commit 1ab7788

add custom op wrapping mxfp8 dim1 cast cuda kernel
stack-info: PR: #2550, branch: danielvegamyhre/stack/6
1 parent c011bad commit 1ab7788

1 file changed: +91 −1 lines


torchao/prototype/mx_formats/kernels.py

Lines changed: 91 additions & 1 deletion
@@ -16,7 +16,11 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_7,
+    is_sm_at_least_100,
+)
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -32,6 +36,7 @@
     F6_E3M2_EXP_BIAS,
     F32_EXP_BIAS,
 )
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode
 
 
 def get_bits(x: torch.Tensor) -> str:
@@ -1730,3 +1735,88 @@ def triton_quantize_nvfp4(
     x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     raise AssertionError("needs torch version 2.8+ and triton")
+
+
+# MXFP8 CUDA kernel is only built on SM100+
+if is_sm_at_least_100():
+    from torchao.prototype import mxfp8_cuda
+
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
+    def mxfp8_quantize_cuda(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Input shape must be 2D.
+        assert x.ndim == 2
+        rows, cols = x.shape
+
+        # Rows and cols must each be a multiple of the block size (32).
+        block_size = 32
+        assert rows % block_size == 0, "rows must be a multiple of 32"
+        assert cols % block_size == 0, "cols must be a multiple of 32"
+
+        # Convert scaling mode to expected string format and call into kernel.
+        scale_mode_str = scaling_mode.value.lower()
+        output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
+            mxfp8_cuda.quantize(
+                x, rowwise=rowwise, colwise=colwise, scaling_mode=scale_mode_str
+            )
+        )
+        return output_rowwise, output_colwise, scales_rowwise, scales_colwise
+
+    @mxfp8_quantize_cuda.register_fake
+    def _(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert x.ndim == 2
+        rows, cols = x.shape
+        block_size = 32
+        assert rows % block_size == 0, "rows must be a multiple of 32"
+        assert cols % block_size == 0, "cols must be a multiple of 32"
+        num_row_blocks = rows // 32
+        num_col_blocks = cols // 32
+
+        # rowwise
+        if rowwise:
+            output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
+            scales_rowwise = x.new_empty(
+                rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
+            )
+        else:
+            output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
+            scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
+
+        # colwise
+        if colwise:
+            # column major
+            output_colwise = torch.empty_strided(
+                (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
+            )
+
+            # colwise scales are written in column-major format to avoid uncoalesced global memory accesses
+            scales_colwise = torch.empty_strided(
+                (cols, num_row_blocks),
+                (1, cols),
+                dtype=torch.float8_e8m0fnu,
+                device=x.device,
+            )
+        else:
+            output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
+            scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)
+
+        return output_rowwise, output_colwise, scales_rowwise, scales_colwise
+
+else:
+
+    def mxfp8_quantize_cuda(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        raise NotImplementedError("needs torch version 2.8+ and sm100")
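
For context, a minimal usage sketch of the op added above, assuming an SM100+ GPU with the mxfp8_cuda extension built. The input shape, dtype, and the torch.compile check are illustrative assumptions and not part of this commit:

import torch

from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode

# Hypothetical input; both dims must be divisible by the 32-element block size.
x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)

# Colwise (dim1) cast only, the path this kernel targets.
_, out_colwise, _, scales_colwise = mxfp8_quantize_cuda(
    x, rowwise=False, colwise=True, scaling_mode=ScaleCalculationMode.FLOOR
)
print(out_colwise.shape)     # torch.Size([256, 1024]), float8_e4m3fn, column-major strides
print(scales_colwise.shape)  # torch.Size([1024, 8]), float8_e8m0fnu (8 = 256 rows / 32)

# Because the op is registered via torch.library.custom_op with a fake impl,
# it should also trace under torch.compile without running the CUDA kernel at trace time.
compiled = torch.compile(lambda t: mxfp8_quantize_cuda(t, rowwise=False, colwise=True))
_ = compiled(x)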
