
Commit 7ce84dd (parent 3a5d3ba)

add wrapper with dtensor handling for mxfp8 dim1 cast kernel

2 files changed (+56 −4 lines)


torchao/prototype/mx_formats/kernels.py

Lines changed: 7 additions & 3 deletions
@@ -1811,11 +1811,15 @@ def custom_mxfp8_quantize_cuda_dim1_sharding(
         colwise: bool = True,
         scaling_mode: str = "floor",
     ):
-        replicate = ([Replicate(), Replicate()], [Replicate(), None])
+        # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise, colwise, scaling_mode)
+        replicate = (
+            [None, Replicate(), None, Replicate()],
+            [None, Replicate(), None, None],
+        )
         # Note that the data is returned transposed, which is why
         # we flip the sharding dim below
-        shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
-        shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
+        shard_dim0 = ([None, Shard(1), None, Shard(1)], [None, Shard(0), None, None])
+        shard_dim1 = ([None, Shard(0), None, Shard(0)], [None, Shard(1), None, None])
         acceptable_shardings = [replicate, shard_dim0, shard_dim1]
         return acceptable_shardings
 else:
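
For context, the comment added in this hunk shows why the placement lists grew to four entries: mxfp8_quantize_cuda returns four tensors, of which the second and fourth are the colwise data and colwise scales, so each sharding strategy now carries one placement slot per return value and per argument, with None where a slot does not participate. Below is a minimal sketch of the colwise path, assuming a GPU build where the CUDA kernel is available; the shapes and dtype are illustrative, not from the commit.

# Illustrative only -- not part of the commit. Assumes a GPU build where the
# mxfp8 CUDA kernel is available.
import torch
from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

# Request only the colwise (dim1) cast; the unnamed outputs are unused here,
# mirroring the comment added in the hunk above.
_, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(
    x, rowwise=False, colwise=True, scaling_mode="floor"
)

# Per the in-code note, the colwise data comes back transposed, which is why
# the sharding strategies above flip Shard(0) and Shard(1).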

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 49 additions & 1 deletion
@@ -19,7 +19,10 @@
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -66,6 +69,51 @@ def _triton_to_mxfp8_dim1_wrapper(
     return mx_tensor


+def _cuda_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    _, a_data, _, a_scale = mxfp8_quantize_cuda(
+        a,
+        rowwise=False,
+        colwise=True,
+        scaling_mode="floor",
+    )
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
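
A hedged usage sketch of the new wrapper, not part of the commit: how _cuda_to_mxfp8_dim1_wrapper might be called on a plain weight tensor. The block size, element dtype, and MXGemmKernelChoice.CUBLAS value (and its import path) are illustrative assumptions; when the input is a DTensor, the same call instead returns a DTensor whose local tensor is the MXTensor, as the DTensor branch above shows.

# Illustrative usage sketch -- parameter values below are assumptions.
import torch
from torchao.prototype.mx_formats.config import MXGemmKernelChoice  # assumed import path
from torchao.prototype.mx_formats.mx_linear import _cuda_to_mxfp8_dim1_wrapper

w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Cast along dim1 (colwise) via the CUDA kernel; the result is an MXTensor,
# or a DTensor wrapping an MXTensor when `w` itself is a DTensor.
w_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
    w,
    block_size=32,                    # assumed MX block size
    elem_dtype=torch.float8_e4m3fn,   # assumed low-precision element dtype
    hp_dtype=w.dtype,                 # original high-precision dtype
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,  # assumed kernel choice
)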
