Skip to content

Commit 8f34111

Browse files
committed
Revert "CUDA: fix quantized KV cache + multiple sequences (ggml-org#14822)"
This reverts commit 07a19e2.
1 parent: 8d05773; commit: 8f34111

File tree

2 files changed: 35 additions (+35) and 107 deletions (-107).

ggml/src/ggml-cuda/convert.cu

Lines changed: 17 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -6,33 +6,24 @@
66
#define CUDA_Q8_0_NE_ALIGN 2048
77

88
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
9-
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
10-
const int64_t ne00, const int64_t ne01, const int64_t ne02,
11-
const int64_t s01, const int64_t s02, const int64_t s03) {
12-
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
9+
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
10+
const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
1311

14-
if (i00 >= ne00) {
12+
if (i >= k) {
1513
return;
1614
}
1715

18-
const int64_t i01 = blockIdx.y;
19-
const int64_t i02 = blockIdx.z % ne02;
20-
const int64_t i03 = blockIdx.z / ne02;
21-
22-
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
23-
24-
const int64_t ib = ibx0 + i00/qk; // block index
25-
const int64_t iqs = (i00%qk)/qr; // quant index
26-
const int64_t iybs = i00 - i00%qk; // y block start index
16+
const int64_t ib = i/qk; // block index
17+
const int64_t iqs = (i%qk)/qr; // quant index
18+
const int64_t iybs = i - i%qk; // y block start index
2719
const int64_t y_offset = qr == 1 ? 1 : qk/2;
2820

2921
// dequantize
3022
dfloat2 v;
3123
dequantize_kernel(vx, ib, iqs, v);
3224

33-
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
34-
y[iy0 + 0] = v.x;
35-
y[iy0 + y_offset] = v.y;
25+
y[iybs + iqs + 0] = v.x;
26+
y[iybs + iqs + y_offset] = v.y;
3627
}
3728

3829
template <bool need_check>
@@ -466,17 +457,9 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
466457
}
467458

468459
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
469-
static void dequantize_block_cuda(const void * vx, dst_t * y,
470-
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
471-
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
472-
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
473-
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
474-
(vx, y, ne00, ne01, ne02, s01, s02, s03);
475-
}
476-
477-
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
478-
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
479-
dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
460+
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
461+
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
462+
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
480463
}
481464

482465
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -641,14 +624,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
641624
case GGML_TYPE_Q4_1:
642625
return dequantize_row_q4_1_cuda;
643626
case GGML_TYPE_Q5_0:
644-
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
627+
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
645628
case GGML_TYPE_Q5_1:
646-
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
629+
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
647630
case GGML_TYPE_Q8_0:
648631
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
649632
return dequantize_block_q8_0_f16_cuda;
650633
}
651-
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
634+
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
652635
case GGML_TYPE_Q2_K:
653636
return dequantize_row_q2_K_cuda;
654637
case GGML_TYPE_Q3_K:
@@ -693,11 +676,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
693676
case GGML_TYPE_Q4_1:
694677
return dequantize_row_q4_1_cuda;
695678
case GGML_TYPE_Q5_0:
696-
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
679+
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
697680
case GGML_TYPE_Q5_1:
698-
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
681+
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
699682
case GGML_TYPE_Q8_0:
700-
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
683+
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
701684
case GGML_TYPE_Q2_K:
702685
return dequantize_row_q2_K_cuda;
703686
case GGML_TYPE_Q3_K:
@@ -739,16 +722,6 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
739722
switch (type) {
740723
case GGML_TYPE_F32:
741724
return convert_unary_cuda<float>;
742-
case GGML_TYPE_Q4_0:
743-
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
744-
case GGML_TYPE_Q4_1:
745-
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
746-
case GGML_TYPE_Q5_0:
747-
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
748-
case GGML_TYPE_Q5_1:
749-
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
750-
case GGML_TYPE_Q8_0:
751-
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
752725
case GGML_TYPE_BF16:
753726
return convert_unary_cuda<nv_bfloat16>;
754727
default:
@@ -760,16 +733,6 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
760733
switch (type) {
761734
case GGML_TYPE_F32:
762735
return convert_unary_cuda<float, nv_bfloat16>;
763-
case GGML_TYPE_Q4_0:
764-
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
765-
case GGML_TYPE_Q4_1:
766-
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
767-
case GGML_TYPE_Q5_0:
768-
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
769-
case GGML_TYPE_Q5_1:
770-
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
771-
case GGML_TYPE_Q8_0:
772-
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
773736
case GGML_TYPE_F16:
774737
return convert_unary_cuda<half, nv_bfloat16>;
775738
default:
@@ -781,16 +744,6 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
781744
switch (type) {
782745
case GGML_TYPE_F16:
783746
return convert_unary_cuda<half, float>;
784-
case GGML_TYPE_Q4_0:
785-
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
786-
case GGML_TYPE_Q4_1:
787-
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
788-
case GGML_TYPE_Q5_0:
789-
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
790-
case GGML_TYPE_Q5_1:
791-
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
792-
case GGML_TYPE_Q8_0:
793-
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
794747
case GGML_TYPE_BF16:
795748
return convert_unary_cuda<nv_bfloat16, float>;
796749
default:

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 18 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -725,58 +725,33 @@ void launch_fattn(
725725
size_t nb23 = V ? V->nb[3] : nb13;
726726

727727
if (need_f16_K && K->type != GGML_TYPE_F16) {
728+
GGML_ASSERT(ggml_is_contiguously_allocated(K));
729+
K_f16.alloc(ggml_nelements(K));
730+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
731+
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
732+
K_data = (char *) K_f16.ptr;
733+
728734
const size_t bs = ggml_blck_size(K->type);
729735
const size_t ts = ggml_type_size(K->type);
730736

731-
K_f16.alloc(ggml_nelements(K));
732-
if (ggml_is_contiguously_allocated(K)) {
733-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
734-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
735-
736-
nb11 = nb11*bs*sizeof(half)/ts;
737-
nb12 = nb12*bs*sizeof(half)/ts;
738-
nb13 = nb13*bs*sizeof(half)/ts;
739-
} else {
740-
GGML_ASSERT(K->nb[0] == ts);
741-
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
742-
const int64_t s01 = nb11 / ts;
743-
const int64_t s02 = nb12 / ts;
744-
const int64_t s03 = nb13 / ts;
745-
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
746-
747-
nb11 = K->ne[0] * sizeof(half);
748-
nb12 = K->ne[1] * nb11;
749-
nb13 = K->ne[2] * nb12;
750-
}
751-
K_data = (char *) K_f16.ptr;
737+
nb11 = nb11*bs*sizeof(half)/ts;
738+
nb12 = nb12*bs*sizeof(half)/ts;
739+
nb13 = nb13*bs*sizeof(half)/ts;
752740
}
753741

754742
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
743+
GGML_ASSERT(ggml_is_contiguously_allocated(V));
744+
V_f16.alloc(ggml_nelements(V));
745+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
746+
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
747+
V_data = (char *) V_f16.ptr;
748+
755749
const size_t bs = ggml_blck_size(V->type);
756750
const size_t ts = ggml_type_size(V->type);
757751

758-
V_f16.alloc(ggml_nelements(V));
759-
if (ggml_is_contiguously_allocated(V)) {
760-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
761-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
762-
V_data = (char *) V_f16.ptr;
763-
764-
nb21 = nb21*bs*sizeof(half)/ts;
765-
nb22 = nb22*bs*sizeof(half)/ts;
766-
nb23 = nb23*bs*sizeof(half)/ts;
767-
} else {
768-
GGML_ASSERT(V->nb[0] == ts);
769-
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
770-
const int64_t s01 = nb21 / ts;
771-
const int64_t s02 = nb22 / ts;
772-
const int64_t s03 = nb23 / ts;
773-
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
774-
775-
nb21 = V->ne[0] * sizeof(half);
776-
nb22 = V->ne[1] * nb21;
777-
nb23 = V->ne[2] * nb22;
778-
}
779-
V_data = (char *) V_f16.ptr;
752+
nb21 = nb21*bs*sizeof(half)/ts;
753+
nb22 = nb22*bs*sizeof(half)/ts;
754+
nb23 = nb23*bs*sizeof(half)/ts;
780755
}
781756

782757
int parallel_blocks = 1;

0 commit comments

Comments (0)