File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -1812,6 +1812,20 @@ def _(
1812
1812
1813
1813
return output_rowwise , output_colwise , scales_rowwise , scales_colwise
1814
1814
1815
@register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
def custom_mxfp8_quantize_cuda_dim1_sharding(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: str = "floor",
):
    """Return the placement rules accepted for ``mxfp8_quantize_cuda``.

    Registered via ``register_sharding`` for
    ``torch.ops.torchao.mxfp8_quantize_cuda.default``. None of the
    arguments are inspected; the same three rules are returned on every
    call.

    Returns:
        A list of three ``(placements, placements)`` tuples. Presumably the
        first list holds output placements and the second holds one entry
        per op argument (``None`` for non-tensor arguments), following the
        ``register_sharding`` convention -- TODO confirm.

    NOTE(review): this function takes four arguments, yet each rule lists
    only two entries per side -- confirm this matches what
    ``register_sharding`` expects for this op.
    """
    return [
        # Everything replicated.
        ([Replicate(), Replicate()], [Replicate(), None]),
        # The data is returned transposed, so the shard dim in the first
        # list is the opposite of the shard dim in the second list.
        ([Shard(1), Shard(1)], [Shard(0), None]),
        ([Shard(0), Shard(0)], [Shard(1), None]),
    ]
1815
1829
else :
1816
1830
1817
1831
def mxfp8_quantize_cuda (
You can’t perform that action at this time.
0 commit comments