Skip to content

Commit 27bdd3b

Browse files
register custom sharding for mxfp8 dim1 cast cuda kernel
stack-info: PR: #2551, branch: danielvegamyhre/stack/7
1 parent 1ab7788 commit 27bdd3b

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,20 @@ def _(
18111811

18121812
return output_rowwise, output_colwise, scales_rowwise, scales_colwise
18131813

1814+
@register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
def custom_mxfp8_quantize_cuda_dim1_sharding(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
    """Declare the sharding strategies accepted by the mxfp8 quantize CUDA op.

    Returns a list of sharding rules for ``register_sharding``; each entry
    pairs a set of placements for the op's outputs with placements for its
    inputs (presumably (output_specs, input_specs) — the DTensor
    ``register_sharding`` convention; confirm against its documentation).
    """
    # Fully replicated input -> fully replicated outputs.
    fully_replicated = ([Replicate(), Replicate()], [Replicate(), None])
    # The kernel returns the data transposed, so an input sharded on one
    # dim maps to outputs sharded on the opposite dim.
    sharded_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
    sharded_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
    return [fully_replicated, sharded_dim0, sharded_dim1]
18141828
else:
18151829

18161830
def mxfp8_quantize_cuda(

0 commit comments

Comments
 (0)