Commit 376d6d2

Revert "Update function params and corresponding usages. " (#2596)
Revert "Update function params and corresponding usages." This reverts commit 9da7ad5.
1 parent 75fc571 commit 376d6d2

File tree

3 files changed: +16, -30 lines

torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h

Lines changed: 4 additions & 26 deletions
@@ -44,17 +44,7 @@ chunked and interleaved during the packing process.
  * @param input Pointer to the source activation matrix (float32, row-major).
  */
 template <int mr_, int kr_, int sr_>
-inline void pack_activations(
-    float* output,
-    int m,
-    int k,
-    const float* input,
-    int mr,
-    int kr,
-    int sr) {
-  (void)mr; // unused
-  (void)kr; // unused
-  (void)sr; // unused
+inline void pack_activations(float* output, int m, int k, const float* input) {
   activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
 }
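
For orientation, here is a minimal caller-side sketch of the reverted API. It reuses the kernel_api namespace alias and helper names from the test file below; the tile values and buffer names are illustrative, not taken from the kernel. The point of the revert is that the mr/kr/sr tiling factors are supplied only as template arguments.

// Sketch (assumed names and values): pack an m x k float activation matrix.
constexpr int mr_ = 1, kr_ = 32, sr_ = 8;  // illustrative tile sizes
std::vector<float> packed_activations(
    kernel_api::packed_activations_size(m, k, mr_, kr_, sr_));
kernel_api::pack_activations<mr_, kr_, sr_>(
    packed_activations.data(), m, k, source_activations.data());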

@@ -110,7 +100,7 @@ row-major).
  * @param bias Pointer to the bias vector (float32, row-major).
  */
 template <int weight_nbit_, int nr_, int kr_, int sr_>
-void pack_weights(
+void pack_weights_for_groupwise_lut_kernel(
     /*output*/
     void* packed_weights_ptr,
     /*inputs*/
@@ -123,14 +113,7 @@ void pack_weights(
     int lut_group_size,
     bool has_scales,
     bool has_bias,
-    const float* bias,
-    int nr,
-    int kr,
-    int sr) {
-  (void)nr; // unused
-  (void)kr; // unused
-  (void)sr; // unused
-
+    const float* bias) {
   weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
       packed_weights_ptr,
       weight_qvals_indices,
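
Consolidating the test hunks further down, the reverted weight-packing call looks roughly like the sketch below; the elided middle arguments are left elided exactly as in the test, and kernel_api and test_case are names the test already uses.

// Sketch: nr_/kr_/sr_ are now template parameters only, and the bias pointer
// is again the final runtime argument.
kernel_api::pack_weights_for_groupwise_lut_kernel<weight_nbit_, nr_, kr_, sr_>(
    packed_weights.data(),
    test_case.weight_qval_indices.data(),
    test_case.weight_scales.data(),
    /* ...remaining inputs as in the test... */
    flat_lut_group_size,
    has_scales_,
    has_bias,
    test_case.bias.data());
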
@@ -207,12 +190,7 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
  * @param k The K dimension (width) of the activation matrix.
  * @return The byte offset from the start of the buffer.
  */
-inline size_t
-packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) {
-  (void)mr; // unused
-  (void)kr; // unused
-  (void)sr; // unused
-
+inline size_t packed_activations_offset(int m_idx, int k) {
   // For a simple padded row-major format, the offset is just m_idx * k.
   return sizeof(float) * m_idx * k;
 }
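
As a quick numeric check of the reverted helper (the kernel_api qualifier and the concrete numbers are illustrative): the offset now depends only on the row index and k.

// Row m_idx of a padded row-major buffer starts at sizeof(float) * m_idx * k.
// Example: m_idx = 4, k = 128  ->  4 * 4 * 128 = 2048 bytes.
size_t offset = kernel_api::packed_activations_offset(/*m_idx=*/4, /*k=*/128);
assert(offset == 2048);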

torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp

Lines changed: 3 additions & 3 deletions
@@ -71,7 +71,7 @@ void test_groupwise_lowbit_lut_kernel(
   std::vector<float> packed_activations_buffer(
       kernel_api::packed_activations_size(m, k, mr_, kr_, sr_));
   kernel_api::pack_activations<mr_, kr_, sr_>(
-      packed_activations_buffer.data(), m, k, source_activations.data(), mr_, kr_, sr_);
+      packed_activations_buffer.data(), m, k, source_activations.data());
   // 3. Pack Weights
   std::vector<char> packed_weights(kernel_api::packed_weights_size(
       n,
@@ -84,7 +84,7 @@ void test_groupwise_lowbit_lut_kernel(
       kr_,
       sr_));
   kernel_api::
-      pack_weights<weight_nbit_, nr_, kr_, sr_>(
+      pack_weights_for_groupwise_lut_kernel<weight_nbit_, nr_, kr_, sr_>(
           packed_weights.data(),
           test_case.weight_qval_indices.data(),
           test_case.weight_scales.data(),
@@ -95,7 +95,7 @@ void test_groupwise_lowbit_lut_kernel(
           flat_lut_group_size,
           has_scales_,
           has_bias,
-          test_case.bias.data(), nr_, kr_, sr_);
+          test_case.bias.data());

   // 4. Run the kernel
   std::vector<float> output(m * n);

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 9 additions & 1 deletion
@@ -640,10 +640,11 @@ struct groupwise_lowbit_weight_lut_test_case {
     const int total_weights = n * k;
     // Frequencies are controlled by their group sizes.
     assert(total_weights % scale_group_size == 0);
+    assert(total_weights % lut_group_size == 0);

     // The number of unique scales/LUTs is derived directly from their group size.
     const int num_scales = total_weights / scale_group_size;
-    const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size;
+    const int num_luts = total_weights / lut_group_size;
     const int lut_size = 1 << weight_nbit;
     std::mt19937 gen(std::random_device{}());
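
A small worked example of the restored computation (values are illustrative): once both group sizes are asserted to divide total_weights, plain integer division is exact and the previous ceiling form is unnecessary.

const int total_weights = 8 * 128;                        // n * k = 1024
const int scale_group_size = 32, lut_group_size = 256;
assert(total_weights % scale_group_size == 0);            // 1024 % 32 == 0
assert(total_weights % lut_group_size == 0);              // 1024 % 256 == 0
const int num_scales = total_weights / scale_group_size;  // 32 unique scales
const int num_luts = total_weights / lut_group_size;      // 4 unique LUTs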

@@ -725,6 +726,9 @@ struct groupwise_lowbit_weight_lut_test_case {
         int weight_nbit, bool has_scales,
         bool has_bias, bool has_clamp) {

+    std::cout << "[Generator Info] Using 'Per-Group' model.\n"
+              << "  - Both scales and LUTs will switch every " << group_size << " weights." << std::endl;
+
     // Just call the decoupled generator with the same group size for both.
     return _generate_master(
         m, k, n,
@@ -744,6 +748,10 @@ struct groupwise_lowbit_weight_lut_test_case {
       int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
       bool has_bias, bool has_clamp) {

+    std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n"
+              << "  - Scales will switch every " << scale_group_size << " weights.\n"
+              << "  - LUTs will switch every " << lut_group_size << " weights." << std::endl;
+
     return _generate_master(
         m, k, n,
         scale_group_size, lut_group_size,
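
To make the two logged generator models concrete, here is an illustrative comparison (the numbers are not taken from the test suite):

// 'Per-Group' model: one group size drives both scales and LUTs.
const int total_weights = 4 * 256;                              // 1024
const int group_size = 32;
const int per_group_scales = total_weights / group_size;        // 32
const int per_group_luts = total_weights / group_size;          // 32
// 'Decoupled Grouping' model: scales and LUTs switch independently.
const int scale_group_size = 32, lut_group_size = 256;
const int decoupled_scales = total_weights / scale_group_size;  // 32
const int decoupled_luts = total_weights / lut_group_size;      // 4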
