Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 63 additions & 55 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
ctx,
&ctx.f32_one_cache,
ctx.f32_one_cache_element,
&ctx.rms_norm_one_tensor_cache.cache,
ctx.rms_norm_one_tensor_cache.size,
src->ne,
acl_gamma_nb,
1, // dims
Expand All @@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
ctx,
&ctx.f32_zero_cache,
ctx.f32_zero_cache_element,
&ctx.rms_norm_zero_tensor_cache.cache,
ctx.rms_norm_zero_tensor_cache.size,
src->ne,
acl_rstd_nb,
GGML_MAX_DIMS,
Expand Down Expand Up @@ -2249,7 +2249,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* 7. Store the computed values into persistent buffers
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
* (ctx.rope_cache.sin_cache / ctx.rope_cache.cos_cache).
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
Expand All @@ -2266,25 +2266,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);

// used for accuracy testing
bool is_attention = is_q || is_k;

// just compute in first layer in attention
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_attention && !is_fisrt_layer) {
return;
}

ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors

GGML_TENSOR_BINARY_OP_LOCALS
if(src2 == nullptr && ctx.rope_cache.cached) {
// use cache.
return;
}

int64_t theta_scale_length = ne00 / 2;
// Other layers use cache except first layer.
ctx.rope_cache.cached = true;

int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
theta_scale_length * sizeof(float_t)};
Expand All @@ -2302,21 +2297,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}

// init theta scale, just one time
if(ctx.rope_init_ptr == nullptr || !is_attention) {
// theta_scale arange, [0,1,...,ne00/2 - 1]
if(ctx.rope_init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
// theta_scale arange, [0,1,...,ne00/2 - 1]
aclTensor* acl_theta_scale_tensor = nullptr;
// cache theta scale
if (src2 != nullptr || ctx.rope_cache.theta_scale_length != theta_scale_length ||
// theta_scale and freq_scale should not change during the current token inference process,
// so we can directly use == here instead of comparing the absolute difference.
ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.freq_scale != freq_scale) {

ctx.rope_cache.theta_scale_length = theta_scale_length;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;

if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));

aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);

float start = 0;
float step = 1;
float stop = ne00 / 2;
float n_elements = ne00 / 2;
float stop = theta_scale_length;
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

// power
Expand All @@ -2328,35 +2334,37 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
ggml_cann_release_resources(ctx, acl_theta_scale);
} else {
// use cache
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}

// freq_factors
if (src2) {
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
}
// release
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
// freq_factors
if (src2) {
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
}

// init sin_repeat && cos_repeat, one token just init in 0 layer
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
if(ctx.rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
if (ctx.rope_cache.sin_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
if (ctx.rope_cache.cos_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
}
int64_t repeat_theta_length = theta_scale_length * position_length * 2;
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}

aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);

// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type),
Expand Down Expand Up @@ -2397,17 +2405,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}

int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

// repeat
Expand Down Expand Up @@ -2491,10 +2499,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

aclTensor* acl_src = ggml_cann_create_tensor(src0);
Expand Down
61 changes: 38 additions & 23 deletions ggml/src/ggml-cann/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,40 @@ struct ggml_cann_graph {
};
#endif // USE_ACL_GRAPH

struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if(theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if(sin_cache != nullptr) {
ACL_CHECK(aclrtFree(sin_cache));
}
if(cos_cache != nullptr) {
ACL_CHECK(aclrtFree(cos_cache));
}
}

void* theta_scale_cache = nullptr;
void* sin_cache = nullptr;
void* cos_cache = nullptr;
bool cached = false;
int64_t position_length = 0;
int64_t theta_scale_length = 0;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
};

struct ggml_cann_tensor_cache {
~ggml_cann_tensor_cache() {
if(cache != nullptr) {
ACL_CHECK(aclrtFree(cache));
}
}

void* cache = nullptr;
int64_t size = 0;
};

/**
* @brief Context for managing CANN backend operations.
*/
Expand All @@ -376,15 +410,11 @@ struct ggml_backend_cann_context {
bool async_mode;
bool support_set_rows;
// Rope Cache
void* rope_init_ptr = nullptr;
void* rope_sin_ptr = nullptr;
void* rope_cos_ptr = nullptr;
int64_t max_prompt_length = 0;
ggml_cann_rope_cache rope_cache;
// Constant Pool
void* f32_zero_cache = nullptr;
void* f32_one_cache = nullptr;
int64_t f32_zero_cache_element = 0;
int64_t f32_one_cache_element = 0;
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;


aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */

Expand Down Expand Up @@ -424,21 +454,6 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
if(rope_init_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_init_ptr));
}
if(rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_sin_ptr));
}
if(rope_cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_cos_ptr));
}
if(f32_zero_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_zero_cache));
}
if(f32_one_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_one_cache));
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2247,6 +2247,10 @@ static enum ggml_status ggml_backend_cann_graph_compute(
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
release_nz_workspace();

// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;

#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
Expand Down
Loading