1 change: 1 addition & 0 deletions include/llama.h
@@ -1389,6 +1389,7 @@ extern "C" {

int32_t n_p_eval;
int32_t n_eval;
int32_t n_reused;
};

struct llama_perf_sampler_data {
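Note: the new `n_reused` field surfaces through the existing perf accessors. A minimal sketch of how a caller could read it, assuming the public `llama_perf_context()` API from llama.h (not shown in this diff):

```cpp
// Minimal sketch (not part of this PR): reading the new counter through the
// public perf API. llama_perf_context() is assumed from llama.h.
#include "llama.h"
#include <cstdio>

static void print_graph_reuse_stats(const llama_context * ctx) {
    const llama_perf_context_data data = llama_perf_context(ctx);

    // n_reused counts ubatches that skipped a full graph rebuild
    printf("graphs reused: %d across %d eval calls\n", data.n_reused, data.n_eval);
}
```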
175 changes: 87 additions & 88 deletions src/llama-context.cpp
@@ -227,8 +227,8 @@ llama_context::llama_context(

LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

// buffer used to store the computation graph and the tensor meta data
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));

// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +388,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}

ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}

uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
@@ -678,38 +674,59 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

auto * gf = graph_init();
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
auto * res = gf_res_prev.get();
auto * gf = res->get_gf();

auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
// the new graph parameters
// in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype);

// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
if (res->can_reuse(gparams)) {
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
n_reused++;
} else {
res->reset();

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

//const auto t_start_us = ggml_time_us();

gf = model.build_graph(gparams);

//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);

if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
}

res->set_inputs(&ubatch);
// set the input data for the input tensors
{
//const auto t_start_us = ggml_time_us();

res->set_inputs(&ubatch);

const auto status = graph_compute(gf, ubatch.n_tokens > 1);
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
}

const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
ret = status;
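Note: the reuse decision is delegated to `res->can_reuse(gparams)`: the result object keeps the parameters it was last built with and only agrees to reuse when the new parameters pin down the same topology. The actual check lives in llama-graph.cpp and is not part of this hunk; the sketch below is an illustration only, with the exact set of compared fields assumed rather than copied from the source:

```cpp
// Illustration only: an approximation of what can_reuse() has to verify.
// The real implementation is in llama-graph.cpp; the compared fields shown
// here are assumptions, not the actual code.
static bool can_reuse_sketch(const llm_graph_params & prev, const llm_graph_params & cur) {
    // the full graph topology must be uniquely determined by these parameters,
    // so any mismatch forces the rebuild in the else-branch above
    return prev.arch            == cur.arch            &&
           prev.gtype           == cur.gtype           &&
           prev.mctx            == cur.mctx            &&
           prev.n_outputs       == cur.n_outputs       &&
           prev.ubatch.n_tokens == cur.ubatch.n_tokens &&
           prev.ubatch.n_seqs   == cur.ubatch.n_seqs;
}
```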
@@ -767,9 +784,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

n_outputs = n_tokens;

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

const auto causal_attn_org = cparams.causal_attn;

// always use non-causal attention for encoder graphs
@@ -778,7 +792,7 @@
cparams.causal_attn = false;

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

cparams.causal_attn = causal_attn_org;

@@ -844,10 +858,6 @@
}
}

// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());

// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
@@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_new;
}

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1190,10 +1197,6 @@
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();

// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());

return 0;
}

@@ -1275,20 +1278,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// graph
//

int32_t llama_context::graph_max_nodes() const {
return std::max<int32_t>(65536, 5*model.n_tensors());
}

ggml_cgraph * llama_context::graph_init() {
ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
/*.mem_buffer =*/ buf_compute_meta.data(),
/*.no_alloc =*/ true,
};

ctx_compute.reset(ggml_init(params));

return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(65536u, 5u*model.n_tensors());
Member: On a side note, I don't know why the minimum number of nodes is so big. This is a waste of memory for the smaller models. The change seems to be from #11571.

Member Author: I will follow up with a PR to fix this.

}

ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1292,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
}

gf_res_prev->reset();
ggml_backend_sched_reset(sched.get());

// store the n_outputs as it is, and restore it afterwards
// TODO: not sure if needed, might simplify in the future by removing this
const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1304,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
auto * res = gf_res_reserve.get();

this->n_outputs = save_n_outputs;
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

if (!res) {
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
return nullptr;
}
res->reset();

ggml_backend_sched_reset(sched.get());
auto * gf = model.build_graph(gparams);

this->n_outputs = save_n_outputs;

// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1323,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
return gf;
}
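Note: worst-case reservation now builds into a dedicated `gf_res_reserve` result instead of the shared meta buffer, so the graph held in `gf_res_prev` stays valid and can still be reused by the next `process_ubatch()` call. A commented restatement of the reservation path above, with the rationale spelled out (the rationale is inferred from this diff, not quoted from it):

```cpp
// Commented restatement of the reservation path above; the stated rationale is
// inferred from this diff rather than quoted from it.
auto * res = gf_res_reserve.get();              // scratch result, separate from gf_res_prev

const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

res->reset();                                   // start a fresh worst-case graph
auto * gf = model.build_graph(gparams);         // topology for the largest expected ubatch

ggml_backend_sched_reserve(sched.get(), gf);    // size the backend buffers for this worst case
// gf_res_prev is untouched here, so a previously built decode graph can still
// pass can_reuse() on the next process_ubatch() call
```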

llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
}, gf, gtype);
llm_graph_params llama_context::graph_params(
llm_graph_result_i * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}

ggml_status llama_context::graph_compute(
@@ -1930,6 +1921,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
data.t_eval_ms = 1e-3 * t_eval_us;
data.n_p_eval = std::max(1, n_p_eval);
data.n_eval = std::max(1, n_eval);
data.n_reused = std::max(0, n_reused);

return data;
}
@@ -1938,6 +1930,7 @@ void llama_context::perf_reset() {
t_start_us = ggml_time_us();
t_eval_us = n_eval = 0;
t_p_eval_us = n_p_eval = 0;
n_reused = 0;
}

//
@@ -2064,8 +2057,13 @@ void llama_context::opt_epoch_iter(
break;
}

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
auto * res = gf_res_prev.get();

const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);

res->reset();

auto * gf = model.build_graph(gparams);

struct ggml_context * ctx_compute_opt;
{
@@ -2807,6 +2805,7 @@ void llama_perf_context_print(const llama_context * ctx) {
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
}

void llama_perf_context_reset(llama_context * ctx) {
28 changes: 11 additions & 17 deletions src/llama-context.h
@@ -35,8 +35,6 @@ struct llama_context {

ggml_backend_sched_t get_sched() const;

ggml_context * get_ctx_compute() const;

uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
@@ -96,7 +94,7 @@
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch(
llm_graph_result_i * process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
@@ -188,10 +186,7 @@
//

public:
int32_t graph_max_nodes() const;

// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
uint32_t graph_max_nodes() const;

// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@@ -200,12 +195,11 @@
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx);
llm_graph_params graph_params(
llm_graph_result_i * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const;

llm_graph_cb graph_get_cb() const;

@@ -258,8 +252,6 @@ struct llama_context {
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;

ggml_context_ptr ctx_compute;

// training
ggml_opt_context_t opt_ctx = nullptr;

@@ -275,8 +267,8 @@
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;

// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
llm_graph_result_ptr gf_res_prev;
llm_graph_result_ptr gf_res_reserve;

// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
@@ -294,4 +286,6 @@

mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls

mutable int32_t n_reused = 0; // number of times the previous graph was reused
};
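Note: the header now owns two `llm_graph_result_ptr` members in place of the old `buf_compute_meta` byte buffer. The interface that llama-context.cpp relies on can be inferred from the call sites in this PR; the sketch below is a reconstruction for illustration only, since the real declaration lives in llama-graph.h and may differ in names and signatures:

```cpp
// Sketch of the llm_graph_result interface as used by llama-context.cpp in
// this PR, reconstructed from the call sites above for illustration only.
// The actual declaration is in llama-graph.h and may differ.
struct llm_graph_result_sketch {
    ggml_cgraph * get_gf();                                   // graph owned by this result
    void          reset();                                    // discard nodes, start a fresh graph
    void          set_inputs(const llama_ubatch * ubatch);    // copy ubatch data into the input tensors
    bool          can_reuse(const llm_graph_params & params); // true if the topology is unchanged
};
```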