|
16 | 16 | _f32_to_floatx_unpacked,
|
17 | 17 | _floatx_unpacked_to_f32,
|
18 | 18 | )
|
19 |
| -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7 |
| 19 | +from torchao.utils import ( |
| 20 | + TORCH_VERSION_AT_LEAST_2_4, |
| 21 | + TORCH_VERSION_AT_LEAST_2_7, |
| 22 | + is_sm_at_least_100, |
| 23 | +) |
20 | 24 |
|
21 | 25 | # TODO(future): if needed, make the below work on previous PyTorch versions,
|
22 | 26 | # just need to hunt down the previous location of `libdevice`. An assert
|
@@ -1730,3 +1734,90 @@ def triton_quantize_nvfp4(
|
1730 | 1734 | x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
|
1731 | 1735 | ) -> Tuple[torch.Tensor, torch.Tensor]:
|
1732 | 1736 | raise AssertionError("needs torch version 2.8+ and triton")
|
| 1737 | + |
| 1738 | + |
| 1739 | +# MXFP8 CUDA kernel is only built on SM100+ |
| 1740 | +if is_sm_at_least_100(): |
| 1741 | + from torchao.prototype import mxfp8_cuda |
| 1742 | + |
| 1743 | + @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=()) |
| 1744 | + def mxfp8_quantize_cuda( |
| 1745 | + x: torch.Tensor, |
| 1746 | + rowwise: bool = False, |
| 1747 | + colwise: bool = True, |
| 1748 | + scaling_mode: str = "floor", |
| 1749 | + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 1750 | + # Input shape must be 2D. |
| 1751 | + assert x.ndim == 2 |
| 1752 | + rows, cols = x.shape |
| 1753 | + |
| 1754 | + # Block size must be a multiple of 32. |
| 1755 | + block_size = 32 |
| 1756 | + assert rows % block_size == 0, "rows must be a multiple of 32" |
| 1757 | + assert cols % block_size == 0, "cols must be a multiple of 32" |
| 1758 | + |
| 1759 | + # Convert scaling mode to expected string format and call into kernel. |
| 1760 | + output_rowwise, output_colwise, scales_rowwise, scales_colwise = ( |
| 1761 | + mxfp8_cuda.quantize( |
| 1762 | + x, |
| 1763 | + rowwise=rowwise, |
| 1764 | + colwise=colwise, |
| 1765 | + scaling_mode=scaling_mode, |
| 1766 | + ) |
| 1767 | + ) |
| 1768 | + return output_rowwise, output_colwise, scales_rowwise, scales_colwise |
| 1769 | + |
@mxfp8_quantize_cuda.register_fake
def _(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: str = "floor",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation of ``mxfp8_quantize_cuda``.

    Allocates uninitialized output tensors with the shapes, strides and
    dtypes listed below; no quantization is performed.

    Args:
        x: 2D input tensor. Both dimensions must be multiples of 32.
        rowwise: When true, allocate full-size rowwise outputs; otherwise
            the rowwise outputs are empty 1D tensors.
        colwise: When true, allocate full-size colwise outputs; otherwise
            the colwise outputs are empty 1D tensors.
        scaling_mode: Unused here; it does not affect any output metadata.

    Returns:
        ``(output_rowwise, output_colwise, scales_rowwise, scales_colwise)``.

    Raises:
        AssertionError: If ``x`` is not 2D, or if either dimension is not a
            multiple of 32.
    """
    # Explicit raises rather than `assert` statements so the checks still
    # run under `python -O`.
    if x.ndim != 2:
        raise AssertionError("input must be a 2D tensor")
    rows, cols = x.shape

    block_size = 32
    if rows % block_size != 0:
        raise AssertionError("rows must be a multiple of 32")
    if cols % block_size != 0:
        raise AssertionError("cols must be a multiple of 32")
    num_row_blocks = rows // block_size
    num_col_blocks = cols // block_size

    # rowwise
    if rowwise:
        output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn)
        # NOTE(review): rowwise scales carry a trailing singleton dim while
        # colwise scales below are 2D -- confirm this matches the kernel.
        scales_rowwise = x.new_empty(
            rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu
        )
    else:
        output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
        scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

    # colwise
    if colwise:
        # column major: stride 1 along rows, stride `rows` along cols
        output_colwise = torch.empty_strided(
            (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
        )

        # colwise scales are written in column-major format to avoid
        # uncoalesced global memory accesses
        scales_colwise = torch.empty_strided(
            (cols, num_row_blocks),
            (1, cols),
            dtype=torch.float8_e8m0fnu,
            device=x.device,
        )
    else:
        output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn)
        scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu)

    return output_rowwise, output_colwise, scales_rowwise, scales_colwise
| 1814 | + |
| 1815 | +else: |
| 1816 | + |
def mxfp8_quantize_cuda(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: str = "floor",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Stand-in used when ``is_sm_at_least_100()`` is false.

    Accepts the same arguments as the CUDA-backed op but never returns.

    Raises:
        NotImplementedError: Always.
    """
    unavailable_reason = "needs torch version 2.8+ and sm100"
    raise NotImplementedError(unavailable_reason)
0 commit comments