fix: improve formatting and resolve minor bug for better utility #2634

Merged
1 commit merged on Jul 31, 2025
@@ -44,7 +44,17 @@ chunked and interleaved during the packing process.
* @param input Pointer to the source activation matrix (float32, row-major).
*/
template <int mr_, int kr_, int sr_>
inline void pack_activations(float* output, int m, int k, const float* input) {
inline void pack_activations(
float* output,
int m,
int k,
const float* input,
int mr,
int kr,
int sr) {
(void)mr; // unused
(void)kr; // unused
(void)sr; // unused
activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
}
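Note that the wrapper accepts mr, kr, and sr at runtime purely so all pack functions share one signature; the values are discarded and the template parameters mr_, kr_, sr_ still drive the packing. A minimal usage sketch (tile values and buffer setup are illustrative only, mirroring the test change further down):

    constexpr int mr = 1, kr = 32, sr = 8;  // illustrative tile parameters
    std::vector<float> packed(
        kernel_api::packed_activations_size(m, k, mr, kr, sr));
    kernel_api::pack_activations<mr, kr, sr>(
        packed.data(), m, k, activations.data(), mr, kr, sr);  // runtime args ignored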

@@ -100,7 +110,7 @@ row-major).
* @param bias Pointer to the bias vector (float32, row-major).
*/
template <int weight_nbit_, int nr_, int kr_, int sr_>
void pack_weights_for_groupwise_lut_kernel(
void pack_weights(
/*output*/
void* packed_weights_ptr,
/*inputs*/
@@ -113,7 +123,13 @@ void pack_weights_for_groupwise_lut_kernel(
int lut_group_size,
bool has_scales,
bool has_bias,
const float* bias) {
const float* bias,
int nr,
int kr,
int sr) {
(void)nr; // unused
(void)kr; // unused
(void)sr; // unused
weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
packed_weights_ptr,
weight_qvals_indices,
@@ -190,7 +206,11 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
* @param k The K dimension (width) of the activation matrix.
* @return The byte offset from the start of the buffer.
*/
inline size_t packed_activations_offset(int m_idx, int k) {
inline size_t
packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) {
(void)mr; // unused
(void)kr; // unused
(void)sr; // unused
// For a simple padded row-major format, the offset is just m_idx * k.
return sizeof(float) * m_idx * k;
}
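Because the packed activation layout here is plain padded row-major float32, the offset is a straight product and the trailing mr/kr/sr arguments have no effect on it. For example (values chosen purely for illustration), m_idx = 8 with k = 128 yields sizeof(float) * 8 * 128 = 4096 bytes.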
36 changes: 22 additions & 14 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp
@@ -71,7 +71,13 @@ void test_groupwise_lowbit_lut_kernel(
std::vector<float> packed_activations_buffer(
kernel_api::packed_activations_size(m, k, mr_, kr_, sr_));
kernel_api::pack_activations<mr_, kr_, sr_>(
packed_activations_buffer.data(), m, k, source_activations.data());
packed_activations_buffer.data(),
m,
k,
source_activations.data(),
mr_,
kr_,
sr_);
// 3. Pack Weights
std::vector<char> packed_weights(kernel_api::packed_weights_size(
n,
@@ -83,19 +89,21 @@
nr_,
kr_,
sr_));
kernel_api::
pack_weights_for_groupwise_lut_kernel<weight_nbit_, nr_, kr_, sr_>(
packed_weights.data(),
test_case.weight_qval_indices.data(),
test_case.weight_scales.data(),
test_case.weight_luts.data(),
n,
k,
flat_scale_group_size,
flat_lut_group_size,
has_scales_,
has_bias,
test_case.bias.data());
kernel_api::pack_weights<weight_nbit_, nr_, kr_, sr_>(
packed_weights.data(),
test_case.weight_qval_indices.data(),
test_case.weight_scales.data(),
test_case.weight_luts.data(),
n,
k,
flat_scale_group_size,
flat_lut_group_size,
has_scales_,
has_bias,
test_case.bias.data(),
nr_,
kr_,
sr_);

// 4. Run the kernel
std::vector<float> output(m * n);
10 changes: 1 addition & 9 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -640,11 +640,10 @@ struct groupwise_lowbit_weight_lut_test_case {
const int total_weights = n * k;
// Frequencies are controlled by their group sizes.
assert(total_weights % scale_group_size == 0);
assert(total_weights % lut_group_size == 0);

// The number of unique scales/LUTs is derived directly from their group size.
const int num_scales = total_weights / scale_group_size;
const int num_luts = total_weights / lut_group_size;
const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size;
const int lut_size = 1 << weight_nbit;
std::mt19937 gen(std::random_device{}());
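Switching from exact division to a ceiling division (and dropping the corresponding assert) lets the test-case generator handle LUT group sizes that do not evenly divide the total weight count. As a purely illustrative case: with total_weights = 96 and lut_group_size = 64, the old 96 / 64 allocated a single LUT and the removed assert rejected the configuration, whereas (96 + 64 - 1) / 64 = 2 allocates a second table that covers the trailing 32 weights.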

@@ -726,9 +725,6 @@ struct groupwise_lowbit_weight_lut_test_case {
int weight_nbit, bool has_scales,
bool has_bias, bool has_clamp) {

std::cout << "[Generator Info] Using 'Per-Group' model.\n"
<< " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl;

// Just call the decoupled generator with the same group size for both.
return _generate_master(
m, k, n,
@@ -748,10 +744,6 @@
int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
bool has_bias, bool has_clamp) {

std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n"
<< " - Scales will switch every " << scale_group_size << " weights.\n"
<< " - LUTs will switch every " << lut_group_size << " weights." << std::endl;

return _generate_master(
m, k, n,
scale_group_size, lut_group_size,
@@ -28,10 +28,12 @@ void pack_weights_operator(
const float* weight_scales,
const float* weight_luts,
const float* bias) {
TORCHAO_CHECK(
lut_group_size % scale_group_size == 0,
"scale_group_size must devide lut_group_size");
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
if (uk.has_scales) {
TORCHAO_CHECK(
lut_group_size % scale_group_size == 0,
"scale_group_size must divide lut_group_size");
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
}
TORCHAO_CHECK(
lut_group_size % (k * uk.nr) == 0,
"lut_group_size must be a multiple of k*nr");
@@ -139,14 +141,17 @@ void groupwise_lowbit_weight_lut_parallel_operator(
bool has_clamp,
float clamp_min,
float clamp_max) {
TORCHAO_CHECK(
lut_group_size % scale_group_size == 0,
"scale_group_size must divide lut_group_size");
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
if (uk.has_scales) {
TORCHAO_CHECK(
lut_group_size % scale_group_size == 0,
"scale_group_size must divide lut_group_size");
TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
TORCHAO_CHECK(
scale_group_size % uk.kr == 0, "kr must divide scale_group_size");
}

TORCHAO_CHECK(
lut_group_size % (k * uk.nr) == 0, "(k * nr) must divide lut_group_size");
TORCHAO_CHECK(
scale_group_size % uk.kr == 0, "kr must divide scale_group_size");
int config_idx = uk.select_config_idx(m);
auto& kernel_config = uk.configs[config_idx];
int n_step = uk.n_step;
@@ -191,7 +196,7 @@ void groupwise_lowbit_weight_lut_parallel_operator(
mc_tile_size,
k,
activation_row_ptr,
kernel_config.mr,
uk.nr,
uk.kr,
uk.sr);

@@ -150,32 +150,37 @@ struct UKernelConfig {
packed_weights_offset != nullptr,
"packed_weights_offset_fn_type must be set");
TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");

// 2. Validate the Array of Linear Configurations
// At least one configuration must be defined.
TORCHAO_CHECK(
!configs.empty(),
"At least one valid kernel configuration must be provided.");

bool configs_set = true; // first linear config must be set
for (size_t i = 0; i < configs.size(); ++i) {
const auto& config = configs[i];

TORCHAO_CHECK(
config.packed_activations_size != nullptr,
"config.packed_activations_size must be set");
TORCHAO_CHECK(
config.pack_activations != nullptr,
"config.pack_activations must be set");
TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set");

if (i > 0) {
const auto& prev_config = configs[i - 1];
if (configs_set) {
const auto& config = configs[i];

TORCHAO_CHECK(
prev_config.m_step > 0,
"There cannot be a gap in configurations (m_step=0 followed by m_step>0)");
config.packed_activations_size != nullptr,
"config.packed_activations_size must be set");
TORCHAO_CHECK(
prev_config.m_step < config.m_step,
"m_step values in configs must be strictly increasing.");
config.pack_activations != nullptr,
"config.pack_activations must be set");
TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set");

if (i > 0) {
const auto& prev_config = configs[i - 1];
TORCHAO_CHECK(
prev_config.m_step > 0,
"There cannot be a gap in configurations (m_step=0 followed by m_step>0)");
TORCHAO_CHECK(
prev_config.m_step < config.m_step,
"m_step values in configs must be strictly increasing.");
}
if (i + 1 < configs.size()) {
configs_set = (configs[i + 1].m_step >= 1);
}
}
}
}
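The restructured validation only checks the leading run of populated configs: once the next entry reports m_step < 1, the remaining slots are treated as intentionally unset rather than rejected. A rough sketch of an array that now passes (field values are hypothetical, following the layout used at registration time below):

    // Hypothetical: only the first slot is populated; validation stops at the
    // first entry whose m_step is 0 instead of failing on its null pointers.
    // configs = {
    //   {/*m_step=*/1, /*mr=*/1,
    //    &packed_activations_size, &packed_activations_offset,
    //    &pack_activations, &kernel},
    //   {/*m_step=*/0},  // unset
    //   {/*m_step=*/0},  // unset
    // };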
@@ -13,9 +13,7 @@
#include <unordered_map>

#if defined(TORCHAO_BUILD_CPU_AARCH64)
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h>
#endif // TORCHAO_BUILD_CPU_AARCH64

namespace torchao::ops::groupwise_lowbit_weight_lut {
@@ -122,19 +120,22 @@ void register_ukernel_config(
torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut;

using kernel_fn_ptr_t =
decltype(&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, true>);
decltype(&kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
weight_nbit,
true>);
kernel_fn_ptr_t kernel_dispatcher;

if (format.has_scales) {
kernel_dispatcher =
&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/true>;
kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
weight_nbit,
/*has_scales=*/true>;
} else {
kernel_dispatcher =
&kernel_api::
kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/false>;
kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32<
weight_nbit,
/*has_scales=*/false>;
}
if (format.nr == 4 && format.kr == 32 && format.sr == 8) {
log_registration(format, "lut: kernel_lowbit_1x4x32_f32");
log_registration(format, "lut: groupwise_lowbit_weight_lut_kernel_1x4x32");
constexpr int nr = 4;
constexpr int kr = 32;
constexpr int sr = 8;
@@ -152,22 +153,25 @@
/*has_scales=*/format.has_scales,
/*has_bias=*/format.has_bias,
/*packed_weights_size_fn_type=*/
&kernel_api::packed_weights_size<weight_nbit, nr, kr, sr>,
&kernel_api::packed_weights_size,
/*packed_weights_offset_fn_type=*/
&kernel_api::packed_weights_offset,
/*pack_weights_fn_type=*/
&kernel_api::
pack_weights_for_groupwise_lut_kernel<weight_nbit, nr, kr, sr>,
pack_weights<weight_nbit, nr, kr, sr>,
/*configs=*/{});

uk.configs[0] = UKernelConfig::group_config_type(
uk.configs[0] = UKernelConfig::config_type
{m_step,
mr,
&kernel_api::packed_activations_size,
&kernel_api::packed_activations_offset,
&kernel_api::pack_activations<mr, kr, sr>,
kernel_dispatcher});
kernel_dispatcher};

// Register the kernel config.
table.register_ukernel_config(format, uarch, std::move(uk));
return;
}
}
#endif // TORCHAO_BUILD_CPU_AARCH64
@@ -206,7 +210,9 @@ UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) {
register_ukernel_config<weight_nbit>(table, format, uarch);

ukernel = table.get_ukernel_config(header, uarch);
assert(ukernel.has_value() && "Kernel registration failed for the current CPU microarchitecture.");
assert(
ukernel.has_value() &&
"Kernel registration failed for the current CPU microarchitecture.");
return ukernel.value();
#else
throw std::runtime_error(
@@ -63,7 +63,7 @@ struct PackedWeightsFormat {
static_cast<bool>(header.params[4]), // has_bias
header.params[5], // nr
header.params[6], // kr
header.params[7], // sr
header.params[7] // sr
);
}
