Skip to content

Commit d858130

Browse files
modify triton_scale_swizzle kernel to accept column major inputs
1 parent 7ce84dd commit d858130

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
14041404
scale_cols,
14051405
output_ptr,
14061406
input_row_stride,
1407+
input_col_stride,
14071408
output_block_stride,
14081409
BLOCK_ROWS: tl.constexpr,
14091410
BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@ def triton_scale_swizzle(
14231424
mask = (global_rows < scale_rows) & (global_cols < scale_cols)
14241425

14251426
input_scales = tl.load(
1426-
scale_ptr + global_rows * input_row_stride + global_cols,
1427+
scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
14271428
mask=mask,
14281429
other=0.0,
14291430
)
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14631464
assert scale_tensor.element_size() == 1, (
14641465
"Expected element size to be 1 byte (8 bits)"
14651466
)
1466-
assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
14671467

14681468
rows, cols = scale_tensor.shape
14691469

@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14761476
out = scale_tensor.new_empty((padded_rows, padded_cols))
14771477

14781478
# Input strides, read from the tensor so non-contiguous (e.g. column-major) layouts work
1479-
input_row_stride = cols
1479+
input_row_stride = scale_tensor.stride()[0]
1480+
input_col_stride = scale_tensor.stride()[1]
14801481

14811482
# We probably want to handle multiple blocks per tile, but for now keep it simple
14821483
BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14951496
cols,
14961497
out.view(torch.uint8),
14971498
input_row_stride,
1499+
input_col_stride,
14981500
output_block_stride,
14991501
BLOCK_ROWS=BLOCK_ROWS,
15001502
BLOCK_COLS=BLOCK_COLS,

0 commit comments

Comments
 (0)