@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
     scale_cols,
     output_ptr,
     input_row_stride,
+    input_col_stride,
     output_block_stride,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@ def triton_scale_swizzle(
     mask = (global_rows < scale_rows) & (global_cols < scale_cols)

     input_scales = tl.load(
-        scale_ptr + global_rows * input_row_stride + global_cols,
+        scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
         mask=mask,
         other=0.0,
     )
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     assert scale_tensor.element_size() == 1, (
         "Expected element size to be 1 byte (8 bits)"
     )
-    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

     rows, cols = scale_tensor.shape
@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     out = scale_tensor.new_empty((padded_rows, padded_cols))

     # Input stride (for row-major format)
-    input_row_stride = cols
+    input_row_stride = scale_tensor.stride()[0]
+    input_col_stride = scale_tensor.stride()[1]

     # We probably want handle multiple blocks per tile but for now keep it simple
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         cols,
         out.view(torch.uint8),
         input_row_stride,
+        input_col_stride,
         output_block_stride,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
0 commit comments