llama : reuse compute graphs #14482
@@ -227,8 +227,8 @@ llama_context::llama_context(

     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));

     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +388,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }

-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,38 +674,59 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }

-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }

-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
-
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
-
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
-
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
-    }
-
-    res->set_inputs(&ubatch);
-
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
+
+    // the new graph parameters
+    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
+
+    if (res->can_reuse(gparams)) {
+        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+
+        n_reused++;
+    } else {
+        res->reset();
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        //const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
+    }
+
+    // set the input data for the input tensors
+    {
+        //const auto t_start_us = ggml_time_us();
+
+        res->set_inputs(&ubatch);
+
+        //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+    }
+
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
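The hunk above is the core of the change: process_ubatch() now keeps the previously built graph result alive in gf_res_prev and rebuilds it only when the new parameters could produce a different topology; otherwise it bumps n_reused and goes straight to setting inputs and computing. A minimal sketch of that reuse pattern, using hypothetical simplified Params/GraphResult types rather than the real llm_graph_params/llm_graph_result classes:

// Hypothetical illustration of the reuse pattern in process_ubatch(); the real
// llm_graph_result::can_reuse() checks the full set of topology-determining parameters.
#include <cstdio>

struct Params {
    int n_tokens;   // stand-in for the ubatch/graph parameters
    int n_outputs;

    bool operator==(const Params & other) const {
        return n_tokens == other.n_tokens && n_outputs == other.n_outputs;
    }
};

struct GraphResult {
    Params params;
    bool   valid = false;

    // reuse is only safe if the cached graph was built from identical parameters
    bool can_reuse(const Params & p) const { return valid && params == p; }

    void rebuild(const Params & p) { params = p; valid = true; } // expensive path
};

int main() {
    GraphResult res;
    int n_reused = 0;

    const Params batches[] = {{1, 1}, {1, 1}, {8, 8}, {1, 1}};
    for (const auto & p : batches) {
        if (res.can_reuse(p)) {
            n_reused++;     // fast path: skip graph construction and scheduler allocation
        } else {
            res.rebuild(p); // slow path: reset, rebuild and re-allocate the graph
        }
        // in both paths the inputs are set and the graph is computed afterwards
    }

    std::printf("graphs reused: %d\n", n_reused); // prints 1 for this sequence
    return 0;
}

Token-by-token decoding repeats the same ubatch shape from one call to the next, which is the case where the fast path pays off.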
@@ -767,9 +784,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

     n_outputs = n_tokens;

-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;

     // always use non-causal attention for encoder graphs
@@ -778,7 +792,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;

     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

     cparams.causal_attn = causal_attn_org;

@@ -844,10 +858,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }

-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
         n_outputs = n_outputs_new;
     }

-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1190,10 +1197,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
     return 0;
 }

@@ -1275,20 +1278,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //

-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
-}
-
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }

 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {

Review comment on graph_max_nodes(): On a side note, I don't know why the minimum number of nodes is so big. This is a waste of memory for the smaller models. The change seems to be from #11571.

Reply: I will follow up with a PR to fix this.
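To put the comment in numbers: with the 65536-node floor, the per-node metadata alone (the same overhead formula the removed constructor code used to size buf_compute_meta) lands in the tens of megabytes regardless of model size. A back-of-the-envelope check, assuming the ggml headers and library are available to build against; this is an illustrative snippet, not code from the PR:

// Estimate the graph metadata implied by the 65536-node minimum, reusing the
// formula the removed constructor code applied to buf_compute_meta.
#include <cstdio>
#include "ggml.h"

int main() {
    const size_t max_nodes = 65536; // floor returned by graph_max_nodes() for small models

    const size_t meta_bytes = ggml_tensor_overhead()*max_nodes
                            + ggml_graph_overhead_custom(max_nodes, /*grads =*/ false);

    // ggml_tensor_overhead() is a few hundred bytes, so this lands in the tens of MiB
    std::printf("graph meta buffer: %.1f MiB for %zu nodes\n",
            meta_bytes/(1024.0*1024.0), max_nodes);
    return 0;
}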
@@ -1301,6 +1292,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }

+    gf_res_prev->reset();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1304,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
-
-    this->n_outputs = save_n_outputs;
-
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
-
-    ggml_backend_sched_reset(sched.get());
+    auto * res = gf_res_reserve.get();
+
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
+
+    res->reset();
+
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;

     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1323,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     return gf;
 }

-llm_graph_result_ptr llama_context::graph_build(
-                      ggml_context * ctx,
-                       ggml_cgraph * gf,
-                const llama_ubatch & ubatch,
-                      llm_graph_type gtype,
-      const llama_memory_context_i * mctx) {
-    return model.build_graph(
-            {
-                /*.ctx         =*/ ctx,
-                /*.arch        =*/ model.arch,
-                /*.hparams     =*/ model.hparams,
-                /*.cparams     =*/ cparams,
-                /*.ubatch      =*/ ubatch,
-                /*.sched       =*/ sched.get(),
-                /*.backend_cpu =*/ backend_cpu,
-                /*.cvec        =*/ &cvec,
-                /*.loras       =*/ &loras,
-                /*.mctx        =*/ mctx,
-                /*.cross       =*/ &cross,
-                /*.n_outputs   =*/ n_outputs,
-                /*.cb          =*/ graph_get_cb(),
-            }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+                        llm_graph_result_i * res,
+                        const llama_ubatch & ubatch,
+              const llama_memory_context_i * mctx,
+                              llm_graph_type gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }

 ggml_status llama_context::graph_compute(
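The /*.field =*/ comments in graph_params() (and in the removed graph_build()) document positional aggregate initialization: there is no constructor call here, so the initializers must appear in member declaration order, which is why adding the gtype and res members also changes the shape of this list. A small illustration of the idiom, with a made-up params_t/make_params rather than the real types:

// Hypothetical illustration of the positional-initializer idiom used by graph_params().
#include <cstdio>

struct params_t {
    int          n_outputs;
    const char * name;
};

static params_t make_params(int n_outputs) {
    // the /*.field =*/ comments only label the positions; the order must match the struct
    return {
        /*.n_outputs =*/ n_outputs,
        /*.name      =*/ "decoder",
    };
}

int main() {
    const params_t p = make_params(32);
    std::printf("%s: n_outputs = %d\n", p.name, p.n_outputs);
    return 0;
}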
@@ -1930,6 +1921,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms = 1e-3 * t_eval_us;
     data.n_p_eval = std::max(1, n_p_eval);
     data.n_eval = std::max(1, n_eval);
+    data.n_reused = std::max(0, n_reused);

     return data;
 }
@@ -1938,6 +1930,7 @@ void llama_context::perf_reset() {
     t_start_us = ggml_time_us();
     t_eval_us = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused = 0;
 }

 //
@@ -2064,8 +2057,13 @@ void llama_context::opt_epoch_iter(
             break;
         }

-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        auto * res = gf_res_prev.get();
+
+        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+        res->reset();
+
+        auto * gf = model.build_graph(gparams);

         struct ggml_context * ctx_compute_opt;
         {
@@ -2807,6 +2805,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
 }

 void llama_perf_context_reset(llama_context * ctx) {