Skip to content

Commit 923e3ea

Browse files
authored
cuda : add set rows for bf16 (#14664)
1 parent e743cdd commit 923e3ea

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
32263226
} break;
32273227
case GGML_OP_SET_ROWS:
32283228
{
3229-
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
3230-
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3229+
#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
3230+
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
32313231
op->src[0]->type == GGML_TYPE_F32 &&
32323232
op->src[1]->type == GGML_TYPE_I64;
32333233
} break;

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
1010
*dst_h = __float2half(*src_f);
1111
}
1212

13+
// Specialization of set_rows_1 for a float source and an nv_bfloat16 destination.
// Reads one float through src_f and stores it through dst_b. The float -> nv_bfloat16
// conversion is left to the assignment itself: no explicit conversion intrinsic is
// called here, unlike the <float, half> specialization, which calls __float2half.
template<>
14+
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
15+
*dst_b = *src_f;
16+
}
17+
1318
template<>
1419
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
1520
*dst_f = *src_f;
@@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
124129
nb1, nb2, nb3,
125130
stream
126131
);
132+
} else if (dst->type == GGML_TYPE_BF16) {
133+
set_rows_cuda(
134+
src0_d, src1_d, (nv_bfloat16*)dst->data,
135+
ne00, ne01, ne02, ne03,
136+
ne10, ne11, ne12, ne13,
137+
nb01, nb02, nb03,
138+
nb10, nb11, nb12,
139+
nb1, nb2, nb3,
140+
stream
141+
);
127142
} else {
128143
GGML_ABORT("unsupported type");
129144
}

0 commit comments

Comments
 (0)