1515
1616#include " network_simplex_simple.h"
1717#include " network_simplex_simple_omp.h"
18+ #include " sparse_bipartitegraph.h"
1819#include " EMD.h"
1920#include < cstdint>
21+ #include < unordered_map>
2022
2123
2224int EMD_wrap (int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
216218
217219 return ret;
218220}
221+
222+ // ============================================================================
223+ // SPARSE VERSION: Accepts edge list instead of dense cost matrix
224+ // ============================================================================
225+ int EMD_wrap_sparse (
226+ int n1,
227+ int n2,
228+ double *X,
229+ double *Y,
230+ uint64_t n_edges,
231+ int64_t *edge_sources,
232+ int64_t *edge_targets,
233+ double *edge_costs,
234+ int64_t *flow_sources_out,
235+ int64_t *flow_targets_out,
236+ double *flow_values_out,
237+ uint64_t *n_flows_out,
238+ double *alpha,
239+ double *beta,
240+ double *cost,
241+ uint64_t maxIter
242+ ) {
243+ using namespace lemon ;
244+
245+ uint64_t n = 0 ;
246+ for (int i = 0 ; i < n1; i++) {
247+ double val = *(X + i);
248+ if (val > 0 ) {
249+ n++;
250+ } else if (val < 0 ) {
251+ return INFEASIBLE;
252+ }
253+ }
254+
255+ uint64_t m = 0 ;
256+ for (int i = 0 ; i < n2; i++) {
257+ double val = *(Y + i);
258+ if (val > 0 ) {
259+ m++;
260+ } else if (val < 0 ) {
261+ return INFEASIBLE;
262+ }
263+ }
264+
265+ std::vector<uint64_t > indI (n); // indI[graph_idx] = original_source_idx
266+ std::vector<uint64_t > indJ (m); // indJ[graph_idx] = original_target_idx
267+ std::vector<double > weights1 (n); // Source masses (positive only)
268+ std::vector<double > weights2 (m); // Target masses (negative for demand)
269+
270+ // Create reverse mapping: original_idx → graph_idx
271+ std::vector<int64_t > source_to_graph (n1, -1 );
272+ std::vector<int64_t > target_to_graph (n2, -1 );
273+
274+ uint64_t cur = 0 ;
275+ for (int i = 0 ; i < n1; i++) {
276+ double val = *(X + i);
277+ if (val > 0 ) {
278+ weights1[cur] = val; // Store the mass
279+ indI[cur] = i; // Forward map: graph → original
280+ source_to_graph[i] = cur; // Reverse map: original → graph
281+ cur++;
282+ }
283+ }
284+
285+ cur = 0 ;
286+ for (int i = 0 ; i < n2; i++) {
287+ double val = *(Y + i);
288+ if (val > 0 ) {
289+ weights2[cur] = -val;
290+ indJ[cur] = i; // Forward map: graph → original
291+ target_to_graph[i] = cur; // Reverse map: original → graph
292+ cur++;
293+ }
294+ }
295+
296+ typedef SparseBipartiteDigraph Digraph;
297+ DIGRAPH_TYPEDEFS (Digraph);
298+
299+ Digraph di (n, m);
300+
301+ std::vector<std::pair<int , int >> edges; // (source, target) pairs
302+ std::vector<uint64_t > edge_to_arc; // edge_to_arc[k] = arc ID for edge k
303+ std::vector<double > arc_costs; // arc_costs[arc_id] = cost (for O(1) lookup)
304+ edges.reserve (n_edges);
305+ edge_to_arc.reserve (n_edges);
306+
307+ uint64_t valid_edge_count = 0 ;
308+ for (uint64_t k = 0 ; k < n_edges; k++) {
309+ int64_t src_orig = edge_sources[k];
310+ int64_t tgt_orig = edge_targets[k];
311+ int64_t src = source_to_graph[src_orig];
312+ int64_t tgt = target_to_graph[tgt_orig];
313+
314+ if (src >= 0 && tgt >= 0 ) {
315+ edges.emplace_back (src, tgt + n);
316+ edge_to_arc.push_back (valid_edge_count);
317+ arc_costs.push_back (edge_costs[k]); // Store cost indexed by arc ID
318+ valid_edge_count++;
319+ } else {
320+ edge_to_arc.push_back (UINT64_MAX);
321+ }
322+ }
323+
324+
325+ di.buildFromEdges (edges);
326+
327+ NetworkSimplexSimple<Digraph, double , double , node_id_type> net (
328+ di, true , (int )(n + m), di.arcNum (), maxIter
329+ );
330+
331+ net.supplyMap (&weights1[0 ], (int )n, &weights2[0 ], (int )m);
332+
333+ for (uint64_t k = 0 ; k < n_edges; k++) {
334+ if (edge_to_arc[k] != UINT64_MAX) {
335+ net.setCost (edge_to_arc[k], edge_costs[k]);
336+ }
337+ }
338+
339+ int ret = net.run ();
340+
341+ if (ret == (int )net.OPTIMAL || ret == (int )net.MAX_ITER_REACHED ) {
342+ *cost = 0 ;
343+ *n_flows_out = 0 ;
344+
345+ Arc a;
346+ di.first (a);
347+ for (; a != INVALID; di.next (a)) {
348+ uint64_t i = di.source (a);
349+ uint64_t j = di.target (a);
350+ double flow = net.flow (a);
351+
352+ uint64_t orig_i = indI[i];
353+ uint64_t orig_j = indJ[j - n];
354+
355+
356+ double arc_cost = arc_costs[a];
357+
358+ *cost += flow * arc_cost;
359+
360+
361+ *(alpha + orig_i) = -net.potential (i);
362+ *(beta + orig_j) = net.potential (j);
363+
364+ if (flow > 1e-15 ) {
365+ flow_sources_out[*n_flows_out] = orig_i;
366+ flow_targets_out[*n_flows_out] = orig_j;
367+ flow_values_out[*n_flows_out] = flow;
368+ (*n_flows_out)++;
369+ }
370+ }
371+ }
372+ return ret;
373+ }
0 commit comments