File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -1811,6 +1811,20 @@ def _(
1811
1811
1812
1812
return output_rowwise , output_colwise , scales_rowwise , scales_colwise
1813
1813
1814
@register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
def custom_mxfp8_quantize_cuda_dim1_sharding(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
    """List the sharding layouts accepted by ``mxfp8_quantize_cuda``.

    The parameters mirror the op's signature; none of them is inspected
    here, so the same three rules are returned for every call.

    Returns:
        A list of ``(output_placements, input_placements)`` pairs: one
        fully replicated rule, followed by one rule for each sharded
        input dim (0, then 1).
    """
    # NOTE(review): every rule lists two output placements and two input
    # placements, while this function takes four parameters. The
    # register_sharding docs show one input entry per op argument (None
    # for non-tensors) -- confirm these lengths match the op schema.
    rules = [([Replicate(), Replicate()], [Replicate(), None])]

    # The op returns its data transposed, so an input sharded on one dim
    # produces outputs sharded on the other dim.
    for input_dim in (0, 1):
        output_dim = 1 - input_dim
        rules.append(
            (
                [Shard(output_dim), Shard(output_dim)],
                [Shard(input_dim), None],
            )
        )

    return rules
1814
1828
else :
1815
1829
1816
1830
def mxfp8_quantize_cuda (
You can’t perform that action at this time.
0 commit comments