Skip to content

Commit f93b1e6

Browse files
Register custom sharding for the mxfp8 dim1-cast CUDA kernel
stack-info: PR: #2551, branch: danielvegamyhre/stack/7
1 parent 4b0038f commit f93b1e6

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,20 @@ def _(
18121812

18131813
return output_rowwise, output_colwise, scales_rowwise, scales_colwise
18141814

1815+
@register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
def custom_mxfp8_quantize_cuda_dim1_sharding(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: str = "floor",
):
    """Enumerate the sharding strategies accepted by the mxfp8 quantize CUDA op.

    Each strategy is a ``(output_placements, input_placements)`` pair as
    expected by ``register_sharding``; ``None`` marks a non-tensor input.

    Args:
        x: Input tensor to be quantized (shape/dtype constraints enforced by
            the kernel itself, not here).
        rowwise: Whether a row-wise cast is requested; unused for strategy
            enumeration but must mirror the op signature.
        colwise: Whether a column-wise cast is requested; likewise unused here.
        scaling_mode: Scale-computation mode forwarded to the kernel;
            presumably "floor" rounding of the shared exponent — confirm
            against the kernel implementation.

    Returns:
        A list of acceptable sharding strategies: fully replicated, input
        sharded on dim 0, and input sharded on dim 1.
    """
    fully_replicated = ([Replicate(), Replicate()], [Replicate(), None])
    # The kernel emits its data transposed relative to the input, so an
    # input sharded along one dim yields outputs sharded along the other —
    # hence the flipped Shard dims between inputs and outputs below.
    input_sharded_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
    input_sharded_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
    return [fully_replicated, input_sharded_dim0, input_sharded_dim1]
18151829
else:
18161830

18171831
def mxfp8_quantize_cuda(

0 commit comments

Comments
 (0)