Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions python/sglang/srt/layers/quantization/moe_wna16.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,46 @@

def get_weight_perm(num_bits: int):
    """Build the 1024-entry index permutation for 4-bit or 8-bit weights.

    The table is assembled in two steps:

    1. For each ``i`` in ``range(32)`` an 8-element group of indices inside a
       16x16 tile is produced (``16 * row + col + 8 * block``), and that
       group is repeated for four consecutive tiles (offsets 0, 256, 512,
       768).  Together these cover every integer in ``[0, 1024)`` once.
    2. The flat list is reshaped into rows of ``len(interleave)`` entries and
       the columns of every row are reordered by the ``num_bits``-specific
       interleave pattern.

    NOTE(review): presumably the result is used to reorder quantized weight
    elements before packing -- confirm against the callers of this function.

    Args:
        num_bits: Bit width selecting the interleave pattern; must be 4 or 8.

    Returns:
        A 1-D ``torch.Tensor`` with 1024 integer indices that is a
        permutation of ``range(1024)``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    # Validate first so an unsupported width fails before any work is done.
    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        col, base_row = divmod(i, 4)
        # Four rows per group: an adjacent pair, plus the same pair 8 rows down.
        rows = (
            2 * base_row,
            2 * base_row + 1,
            2 * (base_row + 4),
            2 * (base_row + 4) + 1,
        )
        # Block 0 uses column `col`, block 1 uses column `col + 8`.
        perm1 = [
            16 * row + col + 8 * block
            for block in (0, 1)
            for row in rows
        ]
        # Repeat the group for four tiles of 256 entries each.
        for j in range(4):
            offset = 256 * j
            perm_list.extend(p + offset for p in perm1)

    perm = np.array(perm_list)
    # Reorder the columns of each row; ravel() only copies when it has to.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)

Expand Down