Skip to content

Commit 0a6208d

Browse files
modify triton_scale_swizzle kernel to accept column major inputs
stack-info: PR: #2553, branch: danielvegamyhre/stack/9
1 parent aaf8f6b commit 0a6208d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,7 @@ def triton_scale_swizzle(
14051405
scale_cols,
14061406
output_ptr,
14071407
input_row_stride,
1408+
input_col_stride,
14081409
output_block_stride,
14091410
BLOCK_ROWS: tl.constexpr,
14101411
BLOCK_COLS: tl.constexpr,
@@ -1424,7 +1425,7 @@ def triton_scale_swizzle(
14241425
mask = (global_rows < scale_rows) & (global_cols < scale_cols)
14251426

14261427
input_scales = tl.load(
1427-
scale_ptr + global_rows * input_row_stride + global_cols,
1428+
scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
14281429
mask=mask,
14291430
other=0.0,
14301431
)
@@ -1464,7 +1465,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14641465
assert scale_tensor.element_size() == 1, (
14651466
"Expected element size to be 1 byte (8 bits)"
14661467
)
1467-
assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
14681468

14691469
rows, cols = scale_tensor.shape
14701470

@@ -1477,7 +1477,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14771477
out = scale_tensor.new_empty((padded_rows, padded_cols))
14781478

14791479
# Input stride (for row-major format)
1480-
input_row_stride = cols
1480+
input_row_stride = scale_tensor.stride()[0]
1481+
input_col_stride = scale_tensor.stride()[1]
14811482

14821483
# We probably want handle multiple blocks per tile but for now keep it simple
14831484
BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1496,6 +1497,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14961497
cols,
14971498
out.view(torch.uint8),
14981499
input_row_stride,
1500+
input_col_stride,
14991501
output_block_stride,
15001502
BLOCK_ROWS=BLOCK_ROWS,
15011503
BLOCK_COLS=BLOCK_COLS,

0 commit comments

Comments
 (0)