@@ -1405,6 +1405,7 @@ def triton_scale_swizzle(
     scale_cols,
     output_ptr,
     input_row_stride,
+    input_col_stride,
     output_block_stride,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
@@ -1424,7 +1425,7 @@ def triton_scale_swizzle(
     mask = (global_rows < scale_rows) & (global_cols < scale_cols)
 
     input_scales = tl.load(
-        scale_ptr + global_rows * input_row_stride + global_cols,
+        scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
         mask=mask,
         other=0.0,
     )
@@ -1464,7 +1465,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     assert scale_tensor.element_size() == 1, (
         "Expected element size to be 1 byte (8 bits)"
     )
-    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
 
     rows, cols = scale_tensor.shape
@@ -1477,7 +1477,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     out = scale_tensor.new_empty((padded_rows, padded_cols))
 
     # Input stride (for row-major format)
-    input_row_stride = cols
+    input_row_stride = scale_tensor.stride()[0]
+    input_col_stride = scale_tensor.stride()[1]
 
     # We probably want handle multiple blocks per tile but for now keep it simple
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1496,6 +1497,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         cols,
         out.view(torch.uint8),
         input_row_stride,
+        input_col_stride,
         output_block_stride,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,