83 changes: 49 additions & 34 deletions auto_round/export/export_to_autoround/qlinear_fp.py
@@ -176,7 +176,7 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
self.weight = scaled_tensor.to(compress_dtype)
else:
compress_dtype = torch.uint8
self.weight_packed = self.pack_fp4_to_uint8(scaled_tensor)
self.weight_packed = pack_fp4_to_uint8(scaled_tensor)

if global_scale is not None:
self.weight_global_scale = global_scale.to(torch.float32).to(device)
@@ -185,46 +185,61 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
self.input_global_scale = input_global_scale.to(torch.float32).to(device)
return

def pack_fp4_to_uint8(self, x: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor with values in the fp4 range into uint8.
As there are 16 valid fp4 values, two fp4 values can be
packed into one uint8. Each fp4 value is mapped to its
particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped
to index 7) which is then represented using 4 bits. Consecutive
pairs of 4 bits are then packed into an uint8.

:param x: tensor to pack
returns: a packed tensor in uint8
"""
def pack_fp4_to_uint8(scaled_tensor: torch.Tensor):
if scaled_tensor.device.type == "cuda":
return pack_fp4_to_uint8_cuda(scaled_tensor)
else:
return pack_fp4_to_uint8_cpu(scaled_tensor)

m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)
# The torch.compile with dynamic=True is incompatible with multiple threads
# https://github.com/pytorch/pytorch/issues/126024
@torch.compiler.disable()
def pack_fp4_to_uint8_cpu(x: torch.Tensor) -> torch.Tensor:
return _pack_fp4_to_uint8(x)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
for i, val in enumerate(kE2M1): # TODO any optimize?
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
# Adapted from https://github.com/neuralmagic/compressed-tensors/pull/400
@torch.compile(fullgraph=True, dynamic=True)
def pack_fp4_to_uint8_cuda(x: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor with values in the fp4 range into uint8.

:param x: tensor to pack
returns: a packed tensor in uint8
"""
return _pack_fp4_to_uint8(x)


def _pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:

m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)
# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)
# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

return packed.reshape(m, n // 2)
return packed.reshape(m, n // 2)
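
As a quick illustration of the packing scheme the docstring describes (two 4-bit codes per byte, first element of each pair in the low nibble), here is a minimal sketch that is not part of the diff. It assumes FLOAT_TO_E2M1 holds the eight E2M1 magnitudes [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], which is consistent with the docstring's examples (0.5 -> index 1, 6.0 -> index 7):

import torch
from auto_round.export.export_to_autoround.qlinear_fp import pack_fp4_to_uint8

# Four valid E2M1 values: 0.5 -> code 1, 6.0 -> code 7,
# -1.0 -> code 2 with the sign bit set (2 + 8 = 10), 0.0 -> code 0
# (codes 2 and 0 depend on the assumed FLOAT_TO_E2M1 table above).
x = torch.tensor([[0.5, 6.0, -1.0, 0.0]], dtype=torch.float32)

packed = pack_fp4_to_uint8(x)  # CPU input, so the non-compiled path runs
# Pairs are packed with the first element in the low nibble:
#   byte 0 = 1  | (7 << 4) = 0x71 = 113
#   byte 1 = 10 | (0 << 4) = 0x0A = 10
print(packed)  # tensor([[113, 10]], dtype=torch.uint8), shape (1, 2)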
83 changes: 83 additions & 0 deletions test/test_cuda/test_packing.py
@@ -0,0 +1,83 @@
import pytest
import torch

from auto_round.export.export_to_autoround.qlinear_fp import FLOAT_TO_E2M1, pack_fp4_to_uint8


# Random sampling from FLOAT_TO_E2M1
def _create_random_e2m1_tensor(shape):
"""Create a tensor of the given shape with random values from FLOAT_TO_E2M1."""
# Create a tensor of indices randomly selected from 0 to len(FLOAT_TO_E2M1)-1
indices = torch.randint(0, len(FLOAT_TO_E2M1), shape)

# Map the indices to their corresponding values
e2m1_tensor = torch.tensor(FLOAT_TO_E2M1, dtype=torch.float32)[indices]
return e2m1_tensor


def pack_fp4_to_uint8_old(x: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor with values in the fp4 range into uint8.
As there are 16 valid fp4 values, two fp4 values can be
packed into one uint8. Each fp4 value is mapped to its
particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped
to index 7) which is then represented using 4 bits. Consecutive
pairs of 4 bits are then packed into an uint8.

:param x: tensor to pack
returns: a packed tensor in uint8
"""

m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
for i, val in enumerate(kE2M1): # TODO any optimize?
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

return packed.reshape(m, n // 2)


qwen_weight_shapes = [
torch.Size([2048, 768]),
torch.Size([768, 2048]),
torch.Size([128, 2048]),
torch.Size([512, 2048]),
torch.Size([4096, 2048]),
torch.Size([151936, 2048]),
torch.Size([2048, 4096]),
]


@pytest.mark.parametrize("shape", qwen_weight_shapes)
def test_packing_fp4(shape):
with torch.device("cuda"):
M, N = shape
random_tensor = _create_random_e2m1_tensor((M, N))
# Pack the tensor using the packing function
packed_tensor = pack_fp4_to_uint8(random_tensor)
packed_tensor_old = pack_fp4_to_uint8_old(random_tensor)
# check equal
assert torch.equal(packed_tensor, packed_tensor_old), "Packed tensors are not equal"
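
A usage note that is not part of the diff: because the test body runs under torch.device("cuda"), the random tensor and both packing calls live on the GPU, so pack_fp4_to_uint8 dispatches to the torch.compile path and the test requires a CUDA device. Assuming a standard pytest setup, the seven parametrized shape cases can be run with:

    pytest test/test_cuda/test_packing.py -v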