Skip to content

Commit 15d2b46

Browse files
authored
rpc : cache and reuse compute graphs (#15405)
Store the last computed graph and reuse it when possible. Also do not return response from GRAPH_COMPUTE and assume it always completes successfully. If this is not the case, the server closes the connection. This saves us a network round trip to the server.
1 parent 6bca76f commit 15d2b46

File tree

2 files changed

+92
-21
lines changed

2 files changed

+92
-21
lines changed

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ extern "C" {
88
#endif
99

1010
#define RPC_PROTO_MAJOR_VERSION 3
11-
#define RPC_PROTO_MINOR_VERSION 0
11+
#define RPC_PROTO_MINOR_VERSION 5
1212
#define RPC_PROTO_PATCH_VERSION 0
1313
#define GGML_RPC_MAX_SERVERS 16
1414

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ enum rpc_cmd {
106106
RPC_CMD_GET_ALLOC_SIZE,
107107
RPC_CMD_HELLO,
108108
RPC_CMD_DEVICE_COUNT,
109+
RPC_CMD_GRAPH_RECOMPUTE,
109110
RPC_CMD_COUNT,
110111
};
111112

@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
205206
uint8_t result;
206207
};
207208

208-
struct rpc_msg_graph_compute_rsp {
209-
uint8_t result;
210-
};
211-
212209
struct rpc_msg_get_device_memory_req {
213210
uint32_t device;
214211
};
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
217214
uint64_t free_mem;
218215
uint64_t total_mem;
219216
};
217+
218+
struct rpc_msg_graph_recompute_req {
219+
uint32_t device;
220+
};
221+
220222
#pragma pack(pop)
221223

222224
// RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
234236
size_t max_size;
235237
};
236238

239+
struct graph_cache {
240+
241+
bool is_cached(const ggml_cgraph * cgraph) {
242+
if ((int)last_graph.size() != cgraph->n_nodes) {
243+
return false;
244+
}
245+
for (int i = 0; i < cgraph->n_nodes; i++) {
246+
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
247+
return false;
248+
}
249+
}
250+
return true;
251+
}
252+
253+
void add(const ggml_cgraph * cgraph) {
254+
last_graph.resize(cgraph->n_nodes);
255+
for (int i = 0; i < cgraph->n_nodes; i++) {
256+
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
257+
}
258+
}
259+
260+
std::vector<ggml_tensor> last_graph;
261+
};
262+
237263
struct ggml_backend_rpc_context {
238264
std::string endpoint;
239265
uint32_t device;
240266
std::string name;
267+
graph_cache gc;
241268
};
242269

243270
struct ggml_backend_rpc_buffer_context {
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
815842

816843
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
817844
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
818-
std::vector<uint8_t> input;
819-
serialize_graph(rpc_ctx->device, cgraph, input);
820-
rpc_msg_graph_compute_rsp response;
821-
auto sock = get_socket(rpc_ctx->endpoint);
822-
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
823-
RPC_STATUS_ASSERT(status);
824-
return (enum ggml_status)response.result;
845+
846+
GGML_ASSERT(cgraph->n_nodes > 0);
847+
bool reuse = rpc_ctx->gc.is_cached(cgraph);
848+
if (reuse) {
849+
rpc_msg_graph_recompute_req request;
850+
request.device = rpc_ctx->device;
851+
auto sock = get_socket(rpc_ctx->endpoint);
852+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
853+
RPC_STATUS_ASSERT(status);
854+
} else {
855+
rpc_ctx->gc.add(cgraph);
856+
std::vector<uint8_t> input;
857+
serialize_graph(rpc_ctx->device, cgraph, input);
858+
auto sock = get_socket(rpc_ctx->endpoint);
859+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
860+
RPC_STATUS_ASSERT(status);
861+
}
862+
return GGML_STATUS_SUCCESS;
825863
}
826864

827865
static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
880918
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
881919
/* .endpoint = */ endpoint,
882920
/* .device = */ device,
883-
/* .name = */ dev_name
921+
/* .name = */ dev_name,
922+
/* .gc = */ {},
884923
};
885924
auto reg = ggml_backend_rpc_add_server(endpoint);
886925
ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
920959

921960
class rpc_server {
922961
public:
923-
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
924-
: backends(std::move(backends)), cache_dir(cache_dir) {
962+
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
963+
: backends(std::move(all_backends)), cache_dir(cache_dir) {
964+
stored_graphs.resize(backends.size());
925965
}
926966
~rpc_server();
927967

@@ -936,11 +976,17 @@ class rpc_server {
936976
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
937977
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
938978
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
939-
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
979+
bool graph_compute(const std::vector<uint8_t> & input);
980+
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
940981
bool init_tensor(const rpc_msg_init_tensor_req & request);
941982
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942983
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
943984

985+
struct stored_graph {
986+
ggml_context_ptr ctx_ptr;
987+
ggml_cgraph * graph;
988+
};
989+
944990
private:
945991
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
946992
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ class rpc_server {
953999
std::vector<ggml_backend_t> backends;
9541000
const char * cache_dir;
9551001
std::unordered_set<ggml_backend_buffer_t> buffers;
1002+
// store the last computed graph for each backend
1003+
std::vector<stored_graph> stored_graphs;
9561004
};
9571005

9581006
void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13941442
return result;
13951443
}
13961444

1397-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1445+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
13981446
// serialization format:
13991447
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
14001448
if (input.size() < 2*sizeof(uint32_t)) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14551503
}
14561504
}
14571505
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1458-
response.result = status;
1506+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1507+
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1508+
stored_graphs[device].graph = graph;
1509+
return true;
1510+
}
1511+
1512+
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
1513+
uint32_t device = request.device;
1514+
if (device >= backends.size()) {
1515+
return false;
1516+
}
1517+
if (stored_graphs[device].graph == nullptr) {
1518+
return false;
1519+
}
1520+
ggml_cgraph * graph = stored_graphs[device].graph;
1521+
LOG_DBG("[%s] device: %u\n", __func__, device);
1522+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1523+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
14591524
return true;
14601525
}
14611526

@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16901755
if (!recv_msg(sockfd, input)) {
16911756
return;
16921757
}
1693-
rpc_msg_graph_compute_rsp response;
1694-
if (!server.graph_compute(input, response)) {
1758+
if (!server.graph_compute(input)) {
16951759
return;
16961760
}
1697-
if (!send_msg(sockfd, &response, sizeof(response))) {
1761+
break;
1762+
}
1763+
case RPC_CMD_GRAPH_RECOMPUTE: {
1764+
rpc_msg_graph_recompute_req request;
1765+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1766+
return;
1767+
}
1768+
if (!server.graph_recompute(request)) {
16981769
return;
16991770
}
17001771
break;

0 commit comments

Comments
 (0)