Skip to content

Commit efd8a41

Browse files
szyszyzysfacebook-github-bot
authored andcommitted
fix: improve formatting and resolve minor bug for better utility (#2634)
Summary: Pull Request resolved: #2634 Reviewed By: metascroy Differential Revision: D79119974
1 parent b4351a7 commit efd8a41

File tree

7 files changed

+107
-71
lines changed

7 files changed

+107
-71
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,17 @@ chunked and interleaved during the packing process.
4444
* @param input Pointer to the source activation matrix (float32, row-major).
4545
*/
4646
template <int mr_, int kr_, int sr_>
47-
inline void pack_activations(float* output, int m, int k, const float* input) {
47+
inline void pack_activations(
48+
float* output,
49+
int m,
50+
int k,
51+
const float* input,
52+
int mr,
53+
int kr,
54+
int sr) {
55+
(void)mr; // unused
56+
(void)kr; // unused
57+
(void)sr; // unused {
4858
activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
4959
}
5060

@@ -100,7 +110,7 @@ row-major).
100110
* @param bias Pointer to the bias vector (float32, row-major).
101111
*/
102112
template <int weight_nbit_, int nr_, int kr_, int sr_>
103-
void pack_weights_for_groupwise_lut_kernel(
113+
void pack_weights(
104114
/*output*/
105115
void* packed_weights_ptr,
106116
/*inputs*/
@@ -113,7 +123,13 @@ void pack_weights_for_groupwise_lut_kernel(
113123
int lut_group_size,
114124
bool has_scales,
115125
bool has_bias,
116-
const float* bias) {
126+
const float* bias,
127+
int nr,
128+
int kr,
129+
int sr) {
130+
(void)nr; // unused
131+
(void)kr; // unused
132+
(void)sr; // unused
117133
weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
118134
packed_weights_ptr,
119135
weight_qvals_indices,
@@ -190,7 +206,11 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
190206
* @param k The K dimension (width) of the activation matrix.
191207
* @return The byte offset from the start of the buffer.
192208
*/
193-
inline size_t packed_activations_offset(int m_idx, int k) {
209+
inline size_t
210+
packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) {
211+
(void)mr; // unused
212+
(void)kr; // unused
213+
(void)sr; // unused
194214
// For a simple padded row-major format, the offset is just m_idx * k.
195215
return sizeof(float) * m_idx * k;
196216
}

torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ void test_groupwise_lowbit_lut_kernel(
7171
std::vector<float> packed_activations_buffer(
7272
kernel_api::packed_activations_size(m, k, mr_, kr_, sr_));
7373
kernel_api::pack_activations<mr_, kr_, sr_>(
74-
packed_activations_buffer.data(), m, k, source_activations.data());
74+
packed_activations_buffer.data(),
75+
m,
76+
k,
77+
source_activations.data(),
78+
mr_,
79+
kr_,
80+
sr_);
7581
// 3. Pack Weights
7682
std::vector<char> packed_weights(kernel_api::packed_weights_size(
7783
n,
@@ -83,19 +89,21 @@ void test_groupwise_lowbit_lut_kernel(
8389
nr_,
8490
kr_,
8591
sr_));
86-
kernel_api::
87-
pack_weights_for_groupwise_lut_kernel<weight_nbit_, nr_, kr_, sr_>(
88-
packed_weights.data(),
89-
test_case.weight_qval_indices.data(),
90-
test_case.weight_scales.data(),
91-
test_case.weight_luts.data(),
92-
n,
93-
k,
94-
flat_scale_group_size,
95-
flat_lut_group_size,
96-
has_scales_,
97-
has_bias,
98-
test_case.bias.data());
92+
kernel_api::pack_weights<weight_nbit_, nr_, kr_, sr_>(
93+
packed_weights.data(),
94+
test_case.weight_qval_indices.data(),
95+
test_case.weight_scales.data(),
96+
test_case.weight_luts.data(),
97+
n,
98+
k,
99+
flat_scale_group_size,
100+
flat_lut_group_size,
101+
has_scales_,
102+
has_bias,
103+
test_case.bias.data(),
104+
nr_,
105+
kr_,
106+
sr_);
99107

100108
// 4. Run the kernel
101109
std::vector<float> output(m * n);

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,10 @@ struct groupwise_lowbit_weight_lut_test_case {
640640
const int total_weights = n * k;
641641
// Frequencies are controlled by their group sizes.
642642
assert(total_weights % scale_group_size == 0);
643-
assert(total_weights % lut_group_size == 0);
644643

645644
// The number of unique scales/LUTs is derived directly from their group size.
646645
const int num_scales = total_weights / scale_group_size;
647-
const int num_luts = total_weights / lut_group_size;
646+
const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size;
648647
const int lut_size = 1 << weight_nbit;
649648
std::mt19937 gen(std::random_device{}());
650649

@@ -726,9 +725,6 @@ struct groupwise_lowbit_weight_lut_test_case {
726725
int weight_nbit, bool has_scales,
727726
bool has_bias, bool has_clamp) {
728727

729-
std::cout << "[Generator Info] Using 'Per-Group' model.\n"
730-
<< " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl;
731-
732728
// Just call the decoupled generator with the same group size for both.
733729
return _generate_master(
734730
m, k, n,
@@ -748,10 +744,6 @@ struct groupwise_lowbit_weight_lut_test_case {
748744
int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
749745
bool has_bias, bool has_clamp) {
750746

751-
std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n"
752-
<< " - Scales will switch every " << scale_group_size << " weights.\n"
753-
<< " - LUTs will switch every " << lut_group_size << " weights." << std::endl;
754-
755747
return _generate_master(
756748
m, k, n,
757749
scale_group_size, lut_group_size,

torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ void pack_weights_operator(
2828
const float* weight_scales,
2929
const float* weight_luts,
3030
const float* bias) {
31-
TORCHAO_CHECK(
32-
lut_group_size % scale_group_size == 0,
33-
"scale_group_size must devide lut_group_size");
34-
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
31+
if (uk.has_scales) {
32+
TORCHAO_CHECK(
33+
lut_group_size % scale_group_size == 0,
34+
"scale_group_size must devide lut_group_size");
35+
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
36+
}
3537
TORCHAO_CHECK(
3638
lut_group_size % (k * uk.nr) == 0,
3739
"lut_group_size must be a multiple of k*nr");
@@ -139,14 +141,17 @@ void groupwise_lowbit_weight_lut_parallel_operator(
139141
bool has_clamp,
140142
float clamp_min,
141143
float clamp_max) {
142-
TORCHAO_CHECK(
143-
lut_group_size % scale_group_size == 0,
144-
"scale_group_size must divide lut_group_size");
145-
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
144+
if (uk.has_scales) {
145+
TORCHAO_CHECK(
146+
lut_group_size % scale_group_size == 0,
147+
"scale_group_size must divide lut_group_size");
148+
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
149+
TORCHAO_CHECK(
150+
scale_group_size % uk.kr == 0, "kr must divide scale_group_size");
151+
}
152+
146153
TORCHAO_CHECK(
147154
lut_group_size % (k * uk.nr) == 0, "(k * nr) must divide lut_group_size");
148-
TORCHAO_CHECK(
149-
scale_group_size % uk.kr == 0, "kr must divide scale_group_size");
150155
int config_idx = uk.select_config_idx(m);
151156
auto& kernel_config = uk.configs[config_idx];
152157
int n_step = uk.n_step;
@@ -191,7 +196,7 @@ void groupwise_lowbit_weight_lut_parallel_operator(
191196
mc_tile_size,
192197
k,
193198
activation_row_ptr,
194-
kernel_config.mr,
199+
uk.nr,
195200
uk.kr,
196201
uk.sr);
197202

torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -150,32 +150,37 @@ struct UKernelConfig {
150150
packed_weights_offset != nullptr,
151151
"packed_weights_offset_fn_type must be set");
152152
TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");
153-
154153
// 2. Validate the Array of Linear Configurations
155154
// At least one configuration must be defined.
156155
TORCHAO_CHECK(
157156
!configs.empty(),
158157
"At least one valid kernel configuration must be provided.");
159158

159+
bool configs_set = true; // first linear config must be set
160160
for (size_t i = 0; i < configs.size(); ++i) {
161-
const auto& config = configs[i];
162-
163-
TORCHAO_CHECK(
164-
config.packed_activations_size != nullptr,
165-
"config.packed_activations_size must be set");
166-
TORCHAO_CHECK(
167-
config.pack_activations != nullptr,
168-
"config.pack_activations must be set");
169-
TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set");
170-
171-
if (i > 0) {
172-
const auto& prev_config = configs[i - 1];
161+
if (configs_set) {
162+
const auto& config = configs[i];
163+
173164
TORCHAO_CHECK(
174-
prev_config.m_step > 0,
175-
"There cannot be a gap in configurations (m_step=0 followed by m_step>0)");
165+
config.packed_activations_size != nullptr,
166+
"config.packed_activations_size must be set");
176167
TORCHAO_CHECK(
177-
prev_config.m_step < config.m_step,
178-
"m_step values in configs must be strictly increasing.");
168+
config.pack_activations != nullptr,
169+
"config.pack_activations must be set");
170+
TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set");
171+
172+
if (i > 0) {
173+
const auto& prev_config = configs[i - 1];
174+
TORCHAO_CHECK(
175+
prev_config.m_step > 0,
176+
"There cannot be a gap in configurations (m_step=0 followed by m_step>0)");
177+
TORCHAO_CHECK(
178+
prev_config.m_step < config.m_step,
179+
"m_step values in configs must be strictly increasing.");
180+
}
181+
if (i + 1 < configs.size()) {
182+
configs_set = (configs[i + 1].m_step >= 1);
183+
}
179184
}
180185
}
181186
}

torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
#include <unordered_map>
1414

1515
#if defined(TORCHAO_BUILD_CPU_AARCH64)
16-
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
17-
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
18-
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
16+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h>
1917
#endif // TORCHAO_BUILD_CPU_AARCH64
2018

2119
namespace torchao::ops::groupwise_lowbit_weight_lut {
@@ -122,19 +120,22 @@ void register_ukernel_config(
122120
torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut;
123121

124122
using kernel_fn_ptr_t =
125-
decltype(&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, true>);
123+
decltype(&kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
124+
weight_nbit,
125+
true>);
126126
kernel_fn_ptr_t kernel_dispatcher;
127127

128128
if (format.has_scales) {
129-
kernel_dispatcher =
130-
&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/true>;
129+
kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
130+
weight_nbit,
131+
/*has_scales=*/true>;
131132
} else {
132-
kernel_dispatcher =
133-
&kernel_api::
134-
kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/false>;
133+
kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
134+
weight_nbit,
135+
/*has_scales=*/false>;
135136
}
136137
if (format.nr == 4 && format.kr == 32 && format.sr == 8) {
137-
log_registration(format, "lut: kernel_lowbit_1x4x32_f32");
138+
log_registration(format, "lut: groupwise_lowbit_weight_lut_kernel_1x4x32");
138139
constexpr int nr = 4;
139140
constexpr int kr = 32;
140141
constexpr int sr = 8;
@@ -152,22 +153,25 @@ void register_ukernel_config(
152153
/*has_scales=*/format.has_scales,
153154
/*has_bias=*/format.has_bias,
154155
/*packed_weights_size_fn_type=*/
155-
&kernel_api::packed_weights_size<weight_nbit, nr, kr, sr>,
156+
&kernel_api::packed_weights_size,
157+
/*packed_weights_offset_fn_type=*/
158+
&kernel_api::packed_weights_offset,
156159
/*pack_weights_fn_type=*/
157160
&kernel_api::
158-
pack_weights_for_groupwise_lut_kernel<weight_nbit, nr, kr, sr>,
161+
pack_weights<weight_nbit, nr, kr, sr>,
159162
/*configs=*/{});
160163

161-
uk.configs[0] = UKernelConfig::group_config_type(
164+
uk.configs[0] = UKernelConfig::config_type
162165
{m_step,
163166
mr,
164167
&kernel_api::packed_activations_size,
165168
&kernel_api::packed_activations_offset,
166169
&kernel_api::pack_activations<mr, kr, sr>,
167-
kernel_dispatcher});
170+
kernel_dispatcher};
168171

169172
// Resgister the kernel config.
170173
table.register_ukernel_config(format, uarch, std::move(uk));
174+
return;
171175
}
172176
}
173177
#endif // TORCHAO_BUILD_CPU_AARCH64
@@ -206,7 +210,9 @@ UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) {
206210
register_ukernel_config<weight_nbit>(table, format, uarch);
207211

208212
ukernel = table.get_ukernel_config(header, uarch);
209-
assert(ukernel.has_value() && "Kernel registration failed for the current CPU microarchitecture.");
213+
assert(
214+
ukernel.has_value() &&
215+
"Kernel registration failed for the current CPU microarchitecture.");
210216
return ukernel.value();
211217
#else
212218
throw std::runtime_error(

torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ struct PackedWeightsFormat {
6363
static_cast<bool>(header.params[4]), // has_bias
6464
header.params[5], // nr
6565
header.params[6], // kr
66-
header.params[7], // sr
66+
header.params[7] // sr
6767
);
6868
}
6969

0 commit comments

Comments
 (0)