Skip to content

Commit c28ee7b

Browse files
authored
Add tests
Differential Revision: D77629227 Pull Request resolved: #2624
1 parent f537110 commit c28ee7b

File tree

1 file changed

+342
-0
lines changed

1 file changed

+342
-0
lines changed
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include <gtest/gtest.h>
8+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
9+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h>
10+
#endif // TORCHAO_BUILD_CPU_AARCH64
11+
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
12+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
13+
#include <torchao/experimental/ops/memory.h>
14+
#include <torchao/experimental/ops/parallel.h>
15+
16+
const float kTol = 1.0e-5;
17+
using namespace torchao::ops::groupwise_lowbit_weight_lut;
18+
19+
template <int weight_nbit, bool has_scales>
20+
UKernelConfig get_ukernel_config(bool has_bias) {
21+
namespace kernel =
22+
torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut;
23+
24+
int preferred_alignment = 16;
25+
int n_step = 8;
26+
constexpr int nr = 4;
27+
constexpr int kr = 32;
28+
constexpr int sr = 8;
29+
constexpr int mr = 1;
30+
int m_step = 1;
31+
32+
auto uk = UKernelConfig::make(
33+
preferred_alignment,
34+
n_step,
35+
nr,
36+
kr,
37+
sr,
38+
weight_nbit,
39+
has_scales,
40+
has_bias,
41+
&kernel::packed_weights_size,
42+
&kernel::packed_weights_offset,
43+
&kernel::pack_weights<weight_nbit, nr, kr, sr>,
44+
/*configs*/ {});
45+
46+
uk.configs[0] = UKernelConfig::config_type{
47+
m_step,
48+
mr,
49+
&kernel::packed_activations_size,
50+
&kernel::packed_activations_offset,
51+
&kernel::pack_activations<mr, kr, sr>,
52+
&kernel::
53+
groupwise_lowbit_weight_lut_kernel_1x4x32<weight_nbit, has_scales>};
54+
return uk;
55+
}
56+
57+
template <int weight_nbit, bool has_scales>
58+
void test_groupwise_lowbit_weight_lut(
59+
int m,
60+
int k,
61+
int n,
62+
int scale_group_size,
63+
int lut_group_size,
64+
bool has_bias,
65+
bool has_clamp,
66+
const UKernelConfig* ukernel_config_arg = nullptr) {
67+
UKernelConfig ukernel_config;
68+
if (ukernel_config_arg != nullptr) {
69+
ukernel_config = *ukernel_config_arg;
70+
} else {
71+
ukernel_config = get_ukernel_config<weight_nbit, has_scales>(has_bias);
72+
}
73+
74+
auto test_case = torchao::groupwise_lowbit_weight_lut_test_case::
75+
generate_with_decoupled_grouping(
76+
m,
77+
k,
78+
n,
79+
scale_group_size,
80+
lut_group_size,
81+
weight_nbit,
82+
has_scales,
83+
has_bias,
84+
has_clamp);
85+
86+
auto output = std::vector<float>(m * n);
87+
88+
for (auto num_threads : {1, 4, 500}) {
89+
torchao::set_num_threads(num_threads);
90+
EXPECT_EQ(torchao::get_num_threads(), num_threads);
91+
auto packed_weight_data_size = ukernel_config.packed_weights_size(
92+
n,
93+
k,
94+
weight_nbit,
95+
scale_group_size,
96+
has_scales,
97+
has_bias,
98+
ukernel_config.nr,
99+
ukernel_config.kr,
100+
ukernel_config.sr);
101+
auto preferred_packed_weight_data_alignment =
102+
ukernel_config.preferred_alignment;
103+
auto packed_weights = torchao::make_aligned_byte_ptr(
104+
preferred_packed_weight_data_alignment, packed_weight_data_size);
105+
106+
pack_weights_operator(
107+
ukernel_config,
108+
// Outputs
109+
packed_weights.get(),
110+
// Inputs
111+
n,
112+
k,
113+
scale_group_size,
114+
lut_group_size,
115+
test_case.weight_qval_indices.data(),
116+
test_case.weight_scales.data(),
117+
test_case.weight_luts.data(),
118+
test_case.bias.data());
119+
120+
groupwise_lowbit_weight_lut_parallel_operator(
121+
ukernel_config,
122+
std::nullopt,
123+
output.data(),
124+
m,
125+
n,
126+
k,
127+
scale_group_size,
128+
lut_group_size,
129+
packed_weights.get(),
130+
test_case.activations.data(),
131+
has_clamp,
132+
test_case.clamp_min,
133+
test_case.clamp_max);
134+
135+
float tol = kTol;
136+
for (int i = 0; i < m * n; i++) {
137+
EXPECT_NEAR(output[i], test_case.expected_output[i], tol);
138+
}
139+
}
140+
}
141+
142+
struct KernelTestParams {
143+
int m;
144+
int k;
145+
int n;
146+
int scale_group_size;
147+
int lut_group_size;
148+
bool has_bias;
149+
bool has_clamp;
150+
};
151+
152+
class ComprehensiveKernelTest
153+
: public ::testing::TestWithParam<KernelTestParams> {};
154+
155+
TEST_P(ComprehensiveKernelTest, kernel_test_has_scales_true) {
156+
const KernelTestParams& params = GetParam();
157+
158+
constexpr bool has_scales = true;
159+
160+
for (int weight_nbit : {1, 2, 3, 4}) {
161+
switch (weight_nbit) {
162+
case 1:
163+
test_groupwise_lowbit_weight_lut<1, has_scales>(
164+
params.m,
165+
params.k,
166+
params.n,
167+
params.scale_group_size,
168+
params.lut_group_size,
169+
params.has_bias,
170+
params.has_clamp);
171+
break;
172+
case 2:
173+
test_groupwise_lowbit_weight_lut<2, has_scales>(
174+
params.m,
175+
params.k,
176+
params.n,
177+
params.scale_group_size,
178+
params.lut_group_size,
179+
params.has_bias,
180+
params.has_clamp);
181+
break;
182+
case 3:
183+
test_groupwise_lowbit_weight_lut<3, has_scales>(
184+
params.m,
185+
params.k,
186+
params.n,
187+
params.scale_group_size,
188+
params.lut_group_size,
189+
params.has_bias,
190+
params.has_clamp);
191+
break;
192+
case 4:
193+
test_groupwise_lowbit_weight_lut<4, has_scales>(
194+
params.m,
195+
params.k,
196+
params.n,
197+
params.scale_group_size,
198+
params.lut_group_size,
199+
params.has_bias,
200+
params.has_clamp);
201+
break;
202+
default:
203+
FAIL() << "Unsupported weight_nbit value: " << weight_nbit;
204+
}
205+
}
206+
}
207+
208+
TEST_P(ComprehensiveKernelTest, kernel_test_has_scales_false) {
209+
const KernelTestParams& params = GetParam();
210+
211+
constexpr bool has_scales = false;
212+
213+
for (int weight_nbit : {1, 2, 3, 4}) {
214+
switch (weight_nbit) {
215+
case 1:
216+
test_groupwise_lowbit_weight_lut<1, has_scales>(
217+
params.m,
218+
params.k,
219+
params.n,
220+
params.scale_group_size,
221+
params.lut_group_size,
222+
params.has_bias,
223+
params.has_clamp);
224+
break;
225+
case 2:
226+
test_groupwise_lowbit_weight_lut<2, has_scales>(
227+
params.m,
228+
params.k,
229+
params.n,
230+
params.scale_group_size,
231+
params.lut_group_size,
232+
params.has_bias,
233+
params.has_clamp);
234+
break;
235+
case 3:
236+
test_groupwise_lowbit_weight_lut<3, has_scales>(
237+
params.m,
238+
params.k,
239+
params.n,
240+
params.scale_group_size,
241+
params.lut_group_size,
242+
params.has_bias,
243+
params.has_clamp);
244+
break;
245+
case 4:
246+
test_groupwise_lowbit_weight_lut<4, has_scales>(
247+
params.m,
248+
params.k,
249+
params.n,
250+
params.scale_group_size,
251+
params.lut_group_size,
252+
params.has_bias,
253+
params.has_clamp);
254+
break;
255+
default:
256+
FAIL() << "Unsupported weight_nbit value: " << weight_nbit;
257+
}
258+
}
259+
}
260+
261+
INSTANTIATE_TEST_SUITE_P(
262+
KernelEdgeCases,
263+
ComprehensiveKernelTest,
264+
::testing::Values(
265+
// Flag-specific tests
266+
KernelTestParams{
267+
8,
268+
64,
269+
16,
270+
32,
271+
256,
272+
/*has_bias=*/true,
273+
/*has_clamp=*/true},
274+
KernelTestParams{
275+
8,
276+
64,
277+
16,
278+
32,
279+
256,
280+
/*has_bias=*/true,
281+
/*has_clamp=*/false},
282+
KernelTestParams{
283+
8,
284+
64,
285+
16,
286+
32,
287+
256,
288+
/*has_bias=*/false,
289+
/*has_clamp=*/true},
290+
KernelTestParams{
291+
8,
292+
64,
293+
16,
294+
32,
295+
256,
296+
/*has_bias=*/false,
297+
/*has_clamp=*/false},
298+
299+
// Prime number dimensions for m and n
300+
KernelTestParams{
301+
7,
302+
64,
303+
13,
304+
32,
305+
256,
306+
/*has_bias=*/true,
307+
/*has_clamp=*/true},
308+
KernelTestParams{
309+
13,
310+
128,
311+
17,
312+
64,
313+
512,
314+
/*has_bias=*/false,
315+
/*has_clamp=*/false},
316+
KernelTestParams{
317+
1,
318+
32,
319+
5,
320+
32,
321+
128,
322+
/*has_bias=*/true,
323+
/*has_clamp=*/false},
324+
325+
// Varying Dimensions and Group Sizes
326+
KernelTestParams{8, 64, 16, 32, 256, true, true},
327+
KernelTestParams{8, 64, 12, 32, 256, true, false},
328+
KernelTestParams{7, 128, 24, 64, 512, false, true},
329+
KernelTestParams{1, 32, 4, 32, 128, true, true},
330+
331+
// Unaligned M
332+
KernelTestParams{7, 64, 16, 32, 256, true, false},
333+
KernelTestParams{5, 64, 16, 32, 256, false, true},
334+
KernelTestParams{1, 64, 16, 32, 256, true, true}));
335+
336+
void PrintTo(const KernelTestParams& params, std::ostream* os) {
337+
*os << "KernelTestParams(m=" << params.m << ", k=" << params.k
338+
<< ", n=" << params.n << ", scale_gs=" << params.scale_group_size
339+
<< ", lut_gs=" << params.lut_group_size
340+
<< ", has_bias=" << (params.has_bias ? "true" : "false")
341+
<< ", has_clamp=" << (params.has_clamp ? "true" : "false") << ")";
342+
}

0 commit comments

Comments
 (0)