@@ -377,7 +377,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
 
 
 @keras_export("keras.quantizers.pack_int4")
-def pack_int4(arr):
+def pack_int4(arr, axis=0):
     """Pack an int4 tensor into an int8 tensor with packed nibbles.
 
     Accepts a Keras-compatible tensor. The input values must already be int8
@@ -396,43 +396,72 @@ def pack_int4(arr):
     if arr.dtype != "int8":
         raise TypeError("Expected int8 tensor for packing")
 
-    shape = ops.shape(arr)
-    rows, cols = shape[0], shape[1]
+    rank = arr.shape.rank
+    # 1. Bring `axis` to the front.
+    perm = [axis] + [i for i in range(rank) if i != axis]
+    inv_perm = [perm.index(i) for i in range(rank)]
+    transposed = ops.transpose(arr, perm)
 
-    orig_rows = rows
-    if rows % 2 == 1:
-        padding_row = ops.zeros((1, cols), dtype="int8")
-        arr = ops.concatenate([arr, padding_row], axis=0)
-        rows += 1
+    # 2. Pad to even length.
+    rows = ops.shape(transposed)[0]
+    needs_pad = ops.equal(ops.mod(rows, 2), 1)
 
-    # Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
-    arr_u = ops.where(arr < 0, arr + 16, arr)
-    arr_u = ops.cast(arr_u, "uint8")
-    arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
-    low = arr_u[:, 0, :]
-    high = arr_u[:, 1, :]
-    packed = ops.bitwise_or(ops.left_shift(high, 4), low)
-    packed = ops.cast(packed, "int8")
-    return packed, ops.shape(packed), orig_rows
+    def _pad(x):
+        pad_shape = ops.concatenate(
+            [ops.array([1]), ops.array(ops.shape(x)[1:])], axis=0
+        )
+        pad_row = ops.zeros(pad_shape, dtype="int8")
+        return ops.concatenate([x, pad_row], axis=0)
+
+    transposed = ops.cond(
+        needs_pad, lambda: _pad(transposed), lambda: transposed
+    )
+    rows_padded = ops.shape(transposed)[0]
+
+    # 3-4. Group in pairs and pack.
+    flat_tail = ops.reshape(transposed, (rows_padded // 2, 2, -1))
+    low = flat_tail[:, 0, :]
+    high = flat_tail[:, 1, :]
+    low_u = ops.where(low < 0, low + 16, low)
+    high_u = ops.where(high < 0, high + 16, high)
+    packed_flat = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
+    packed_flat = ops.cast(packed_flat, "int8")
+
+    # 5-6. Restore shape.
+    packed = ops.reshape(
+        packed_flat,
+        ops.concatenate(
+            [
+                ops.expand_dims(rows_padded // 2, 0),
+                ops.array(ops.shape(transposed)[1:]),
+            ],
+            axis=0,
+        ),
+    )
+    packed = ops.transpose(packed, inv_perm)  # back to original order
+    orig_len = rows  # number of slices before padding
+    return packed, ops.shape(packed), orig_len
 
 
 @keras_export("keras.quantizers.unpack_int4")
-def unpack_int4(packed, orig_rows):
+def unpack_int4(packed, orig_len, axis=0):
     """Unpack packed int4 tensor (ops) to int8 [-8,7]."""
-    # Bitwise operations work element-wise.
-    low = ops.bitwise_and(packed, 0x0F)
-    high = ops.right_shift(packed, 4)
-    high = ops.bitwise_and(high, 0x0F)
-
-    def _to_signed(x):
-        return ops.where(x < 8, x, ops.subtract(x, 16))
-
-    low = _to_signed(low)
-    high = _to_signed(high)
-
-    # Interleave rows back: stacked shape (2, packed_rows, cols)
-    stacked = ops.stack([low, high], axis=1)  # (pairs, 2, cols)
-    unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
-    # Remove potential padded row.
-    unpacked = unpacked_full[:orig_rows, :]
-    return unpacked
+    rank = packed.shape.rank
+    perm = [axis] + [i for i in range(rank) if i != axis]
+    inv_perm = [perm.index(i) for i in range(rank)]
+    transposed = ops.transpose(packed, perm)
+
+    # 1. Split nibbles.
+    low = ops.bitwise_and(transposed, 0x0F)
+    high = ops.bitwise_and(ops.right_shift(transposed, 4), 0x0F)
+    to_signed = lambda x: ops.where(x < 8, x, x - 16)
+    low = to_signed(low)
+    high = to_signed(high)
+
+    # 2. Interleave.
+    stacked = ops.stack([low, high], axis=1)  # (pairs, 2, …)
+    unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))
+
+    # 3. Remove possible padding and restore layout.
+    unpacked = unpacked[:orig_len, ...]
+    return ops.transpose(unpacked, inv_perm)
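
Not part of the commit, but for context: a minimal roundtrip sketch of how the new `axis` argument is intended to be used, assuming the `keras.quantizers.pack_int4` / `unpack_int4` exports added above. Packing pairs adjacent slices along `axis`, so that dimension shrinks to `ceil(orig_len / 2)`; the returned `orig_len` lets `unpack_int4` drop the padded slice again.

```python
# Illustrative roundtrip only; the random weight tensor `w` is a made-up example.
import numpy as np

from keras import ops
from keras.quantizers import pack_int4, unpack_int4

# int4 values occupy [-8, 7] but must be handed in as an int8 tensor.
w = ops.cast(np.random.randint(-8, 8, size=(3, 5)), "int8")

# Pack along axis=1: nibbles are paired along that axis, so (3, 5) -> (3, 3)
# with one padded slice.
packed, packed_shape, orig_len = pack_int4(w, axis=1)

# Unpacking with the same axis and the returned orig_len restores the input.
restored = unpack_int4(packed, orig_len, axis=1)
np.testing.assert_array_equal(
    ops.convert_to_numpy(restored), ops.convert_to_numpy(w)
)
```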