
Commit 1d84542

Align Int4Tensor implementation details with the design of Float8Tensor
Summary: Int4Tensor is the non-preshuffled version of the int4 quantized Tensor; data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N]. Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops):
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan: python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 09c1ec3 commit 1d84542
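For context on the layout described in the summary, here is a minimal sketch (not part of this commit) that quantizes a single Linear layer and inspects the resulting Int4Tensor attributes. The group_size and version keyword arguments and the printed shapes are assumptions based on the summary; the qdata/scale/zero_point attribute names come from the test diff below.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Sketch only: assumes a CUDA device with the fbgemm int4 kernels available,
# and that Int4WeightOnlyConfig accepts group_size/version as shown (the
# summary's "VERSION 2" note is the basis for version=2 here).
N, K, group_size = 256, 128, 128
linear = torch.nn.Linear(K, N, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=group_size, version=2))

w = linear.weight
print(w.qdata.shape)       # packed int4 data, [N, K/2] -> expected torch.Size([256, 64])
print(w.scale.shape)       # [K/group_size, N]         -> expected torch.Size([1, 256])
print(w.zero_point.shape)  # [K/group_size, N]         -> expected torch.Size([1, 256])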

File tree

3 files changed: +591 −143 lines changed

test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 312 additions & 29 deletions
@@ -7,20 +7,79 @@
 import unittest
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.prototype.moe_quant.utils import MoEQuantConfig
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_90,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
+
+
+class Experts(nn.Module):
+    def __init__(
+        self,
+        num_local_experts: int,
+        dim: int,
+        hidden_dim: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> None:
+        super().__init__()
+
+        self.num_local_experts = num_local_experts
+        self.dim = dim
+
+        self.w1: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                dim,
+                hidden_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        self.w2: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                hidden_dim,
+                dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        self.w3: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                dim,
+                hidden_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+    def forward(
+        self,
+        routed_in_egD: torch.Tensor,  # noqa: N803
+    ) -> torch.Tensor:
+        e = self.num_local_experts
+        D = self.dim
+
+        x_egD = routed_in_egD.view(e, -1, D)
+
+        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
+        out_egD = torch.bmm(middle_out_egF, self.w2)
+        out_egD = out_egD.view(-1, D)
+
+        return out_egD
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -61,9 +120,9 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
 
         # check for sliced weight, before and after float8 quantization
@@ -80,31 +139,62 @@ def test_slice(self):
         res = dummy(input)
         assert compute_error(res, res_ref) > 15
 
-    def test_slice_and_copy_(self):
+    def test_slice_preserves_aliasing(self):
+        config = self.config
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data._data.data_ptr() == param_data._data.data_ptr()
+        # Making sure the aliasing is preserved in sliced quantized Tensor
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
-        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data._data[0][0].item()
 
-        # dummy_l has random input (shouldn't be 0)
+    def test_slice_and_copy_similar_to_vllm(self):
+        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
+        # the test is similar to the linked code, but with some hardcoded arguments
+        # and does not use tensor parallelism
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        config = self.config
+        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        quantize_(l, config)
+
+        # high level, we do a narrow for both param.data and the loaded_weights
+        # and do inplace copy_ to copy from the loaded_weights into param.data
+
+        # simulate loaded_weight
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)
+        # making the weight different
+        dummy_l.weight = torch.nn.Parameter(
+            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            requires_grad=False,
+        )
+        quantize_(dummy_l, config)
 
-        param_data.copy_(quantized)
+        output_dim = 0
+        shard_size = 512
+        for tp_rank in [0, 1]:
+            start_idx = tp_rank * shard_size
+            param = l.weight
+            param_data = param.data
+            param_data = param_data.narrow(output_dim, start_idx, shard_size)
+            orig_value = param_data.qdata[0][0].item()
+            loaded_weight = dummy_l.weight
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
-        # making sure param.data is updated
-        assert param.data._data[0][0] != orig_value
+            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
+            assert orig_value != loaded_weight.qdata[0][0]
+            param_data.copy_(loaded_weight)
+            # making sure param.data is updated to loaded_weight
+            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
+            assert torch.equal(param_data.scale, loaded_weight.scale)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_bmm(self):
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -126,20 +216,213 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_to_device(self, sizes):
+        config = self.config
+        M, N, K = sizes
+        dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device=device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_cat(self, sizes):
+        config = self.config
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+        linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+        dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy_linear1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (2 * N, K))
+        self.assertEqual(
+            dummy_linear1.weight.qdata,
+            cat_qweight1.qdata,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.scale,
+            cat_qweight1.scale,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.zero_point,
+            cat_qweight1.zero_point,
+        )
+
+        # making sure cat_qweight1 can be used for inference
+        dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+        dummy_linear1(input_cat1)
+
+        # align the scale and zero_point before concatenation
+        linear2.weight.scale = linear1.weight.scale
+        linear2.weight.zero_point = linear1.weight.zero_point
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (N, 2 * K))
+        ref_data = torch.cat(
+            [
+                linear1.weight.qdata,
+                linear2.weight.qdata,
+            ],
+            dim=1,
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.qdata, ref_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    def test_moe_weight_reshape_ops(self):
+        """This is testing the op call sequence in saving and loading quantization
+        checkpoints in llama-models for llama4
+        (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
+        """
+        # only per row quantization is supported for bmm
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        bmm_config = self.config
+        moe_config = MoEQuantConfig(bmm_config)
+
+        batch_size = 4
+        num_experts = 2
+        input_dim = 64
+        dim = 128
+        hidden_dim = 256
+
+        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
+        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
+        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
+        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
+
+        moes = [moe1, moe2]
+
+        for moe in moes:
+            moe(input)
+
+            def filter_fn(module, fqn):
+                return isinstance(module, Experts)
+
+            # need to transpose before quantizing
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
+            )
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
+            )
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
+            )
+
+            quantize_(moe, moe_config, filter_fn=filter_fn)
+
+            before = moe(input)
+
+            # transposing for resharding support since only 2D resharding is supported
+            new_last_dim = moe.w1.shape[-2]
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+            new_last_dim = moe.w2.shape[-2]
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+            new_last_dim = moe.w3.shape[-2]
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+
+            # transpose again to recover the original weights
+            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
+            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
+            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
+
+            after = moe(input)
+            self.assertEqual(before, after)
+
+        state_dicts = [moe1.state_dict(), moe2.state_dict()]
+        # align the scale parameter so they can be concatenated
+        for key in ["w1", "w2", "w3"]:
+            weights = [st[key] for st in state_dicts]
+            for i in range(1, len(weights)):
+                weights[i].scale = weights[0].scale
+                weights[i].zero_point = weights[0].zero_point
+
+        def process_key(key: str) -> torch.Tensor:
+            tensors = [s[key] for s in state_dicts]
+            # Note: we have a hacky implementation for cat in user codebase
+            # since it is not implemented correctly before
+            if key == "w2":
+                return torch.cat(tensors, dim=-1)
+            else:
+                return torch.cat(tensors, dim=-2)
+
+        new_state_dict = {}
+        for key in ["w1", "w2", "w3"]:
+            new_state_dict[key] = process_key(key)
+
+        moe_combined.w1 = torch.nn.Parameter(
+            moe_combined.w1.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.w2 = torch.nn.Parameter(
+            moe_combined.w2.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.w3 = torch.nn.Parameter(
+            moe_combined.w3.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.load_state_dict(new_state_dict, assign=True)
+        # make sure it runs
+        moe_combined(input)
+
 
+instantiate_parametrized_tests(TestInt4Tensor)
 
 if __name__ == "__main__":
     run_tests()
