get_rows & dequantize function implementation for repacked weights of type q4_K (q4_Kx8) #3291

Open · wants to merge 16 commits into base: master
160 changes: 159 additions & 1 deletion ggml/src/ggml-cpu/repack.cpp
@@ -1180,6 +1180,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR

size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);

return true;
}
case GGML_OP_GET_ROWS:
{
size = 0; // GET_ROWS (standard and repacked) doesn't need a work buffer
return true;
}
default:
@@ -1197,6 +1202,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
case GGML_OP_MUL_MAT_ID:
forward_mul_mat_id(params, op);
return true;
case GGML_OP_GET_ROWS:
forward_get_rows(params, op);
return true;
default:
// GGML_ABORT("fatal error");
break;
@@ -1405,6 +1413,145 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
#undef MMID_MATRIX_ROW
}

void forward_get_rows(const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_Q4_K:
ggml_compute_forward_get_rows_q4_Kx8(params, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}

static void ggml_compute_forward_get_rows_q4_Kx8(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);

assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_type_size(src0->type));
assert(ggml_nrows(dst) == nr);

const int ith = params->ith;
const int nth = params->nth;

// rows per thread
const int dr = (nr + nth - 1) / nth;

// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);

constexpr int nrows_interleaved = 8;
const size_t sizeof_one_repacked_block = sizeof(block_q4_Kx8);

const int num_repacked_blocks_per_row_width = nc / QK_K;

const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;

for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i / (ne11 * ne10);
const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row

GGML_ASSERT(i01 >= 0 && i01 < ne01);

const int row_group_idx = i01 / nrows_interleaved;
const int row_idx_in_group = i01 % nrows_interleaved;

const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;

// Pointer to the first block_q4_Kx8 of the identified row_group_idx
const block_q4_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q4_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);

dequantize_row_q4_Kx8(
p_first_repacked_block_of_group_x8,
(float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
}
}
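
For readers skimming the indexing above, here is a minimal standalone sketch of the row-to-group mapping (hypothetical helper and names, assuming QK_K = 256 and the 8-row interleave of q4_Kx8; not code from this PR):

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative sketch only: rows are interleaved in groups of 8, and each
// group stores nc / QK_K consecutive block_q4_Kx8 super-blocks, so a logical
// row i01 maps to a group byte offset plus an interleave index in the group.
struct row_location {
    size_t group_byte_offset; // offset of the group's first block_q4_Kx8
    int    row_in_group;      // which of the 8 interleaved rows to extract
};

static row_location locate_repacked_row(int64_t i01, int64_t nc,
                                        size_t sizeof_block, int64_t qk_k = 256) {
    const int     nrows_interleaved = 8;
    const int64_t blocks_per_row    = nc / qk_k;                // super-blocks per logical row
    const size_t  group_stride      = (size_t) blocks_per_row * sizeof_block;
    return { (size_t)(i01 / nrows_interleaved) * group_stride,
             (int)   (i01 % nrows_interleaved) };
}
```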

/**
* Dequantizes a single logical row from the repacked q4_Kx8 data format.
*
* @param p_repacked_blocks Pointer to the start of the 'block_q4_Kx8' structures for the entire row.
* @param y Output buffer for the dequantized float values.
* @param k Total number of elements (columns) in the logical row.
* @param row_idx_in_group The index (0-7) of the logical row to extract from the interleaved data.
*/

static void dequantize_row_q4_Kx8(
const void * GGML_RESTRICT p_repacked_blocks,
float * GGML_RESTRICT y,
int64_t k,
int row_idx_in_group) {

assert(k % QK_K == 0);
assert(row_idx_in_group >= 0 && row_idx_in_group < 8);

const int nb = k / QK_K;
const block_q4_Kx8 * blocks = (const block_q4_Kx8 *)p_repacked_blocks;

for (int i = 0; i < nb; i++) {
const block_q4_Kx8 * current_block = &blocks[i];

const float d_super_block = GGML_FP16_TO_FP32(current_block->d[row_idx_in_group]);
const float dmin_super_block = GGML_FP16_TO_FP32(current_block->dmin[row_idx_in_group]);

const uint8_t * ptr_qs_base = current_block->qs;
const uint8_t * ptr_repacked_scales = (const uint8_t *)current_block->scales;
int is = 0, chunk_group_start_idx = 0;
for (int j = 0; j < QK_K; j += 64) {

uint8_t sc1, m1_val, sc2, m2_val;
const uint8_t *scales_repacked_data;

scales_repacked_data = &ptr_repacked_scales[(is + 0) * 12];
get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc1, &m1_val);

scales_repacked_data = &ptr_repacked_scales[(is + 1) * 12];
get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc2, &m2_val);

const float d1 = d_super_block * sc1;
const float m1 = dmin_super_block * m1_val;
const float d2 = d_super_block * sc2;
const float m2 = dmin_super_block * m2_val;

for (int idx = 0; idx < 4; idx++) {
const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8;
for (int l = 0; l < 8; ++l) *y++ = d1 * (ptr_qs_chunk[l] & 0xF) - m1; // 8 low-nibble quants per chunk, 32 in total
}

for (int idx = 0; idx < 4; idx++) {
const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8;
for (int l = 0; l < 8; ++l) *y++ = d2 * (ptr_qs_chunk[l] >> 4) - m2; // 8 high-nibble quants per chunk, 32 in total
}
is += 2;
chunk_group_start_idx += 4;
}
}
}
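
The inner loops apply the standard q4_K affine rule per element; as a reference, a one-line sketch (my naming, not the PR's):

```cpp
#include <cstdint>

// y = (d * sc) * q - (dmin * mn), where q is a 4-bit quant (0..15),
// sc/mn are the 6-bit sub-block scale and min, and d/dmin are the
// fp16 super-block scale factors (already converted to float here).
static inline float dequant_q4_k_element(float d, float dmin,
                                         uint8_t sc, uint8_t mn, uint8_t q) {
    return d * (float) sc * (float) q - dmin * (float) mn;
}
```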

static inline void get_scale_min_k4(int j, const uint8_t *GGML_RESTRICT s, uint8_t *GGML_RESTRICT d, uint8_t *GGML_RESTRICT m) {
if (j < 4) {
*d = s[j] & 63;
*m = s[j + 4] & 63;
} else {
*d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
*m = (s[j + 4] >> 4) | ((s[j - 0] >> 6) << 4);
}
}
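
And a self-contained sanity check of the 6-bit scale/min unpacking rule (hypothetical values, illustration only; the 12-byte array packs 8 scales and 8 mins, with sub-blocks 4..7 borrowing their top 2 bits from bytes 0..7):

```cpp
#include <cassert>
#include <cstdint>

// Standalone copy of get_scale_min_k4's unpacking rule, for illustration only.
static void unpack_scale_min(int j, const uint8_t * s, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = s[j] & 63;
        *m = s[j + 4] & 63;
    } else {
        *d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
        *m = (s[j + 4] >>  4) | ((s[j    ] >> 6) << 4);
    }
}

int main() {
    uint8_t scales[12] = {0};
    scales[1] = 0x7F; // scale[1] = 63, top 2 bits feed scale[5]
    scales[5] = 0xBF; // min[1]   = 63, top 2 bits feed min[5]
    scales[9] = 0xA5; // low nibble -> scale[5], high nibble -> min[5]

    uint8_t d, m;
    unpack_scale_min(1, scales, &d, &m);
    assert(d == 63 && m == 63);

    unpack_scale_min(5, scales, &d, &m);
    assert(d == 0x15); // (0xA5 & 0xF) | ((0x7F >> 6) << 4)
    assert(m == 0x2A); // (0xA5 >> 4)  | ((0xBF >> 6) << 4)
    return 0;
}
```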

int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
@@ -1538,12 +1685,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
} else if (op->op == GGML_OP_GET_ROWS
&& op->src[0]->buffer
&& (ggml_n_dims(op->src[0]) == 2)
&& op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
&& ggml_repack_get_optimal_repack_type(op->src[0])) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[0]->type == GGML_TYPE_Q4_K) {
return true;
}
}
return false;
}

ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
43 changes: 28 additions & 15 deletions src/whisper.cpp
@@ -1437,24 +1437,25 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
// GPU and default CPU backend support all operators
op_supported = true;
} else {
ggml_init_params params = {
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};

ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error("failed to create ggml context");
}
ggml_context * ctx = ctx_ptr.get();

ggml_tensor * op_tensor = nullptr;

int64_t n_ctx = hparams.n_audio_ctx;

switch (op) {
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT & GGML_OP_GET_ROWS (repacked q4_K)
case GGML_OP_MUL_MAT: {
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
op_tensor = ggml_mul_mat(ctx, w, b);

@@ -1466,6 +1467,18 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
w->buffer = nullptr;
break;
}
case GGML_OP_GET_ROWS: {
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
op_tensor = ggml_get_rows(ctx, w, b);

// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
GGML_ASSERT(w->buffer == nullptr);
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
ggml_backend_buffer_free(w->buffer);
w->buffer = nullptr;
break;
}
default: {
op_supported = false;
break;