Skip to content

Commit 04c12a0

Browse files
committed
Add sparse EMD solver with unit tests
- Implement sparse bipartite graph EMD solver in C++ - Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py) - Add unit tests to verify sparse and dense solvers produce identical results - Tests use augmented k-NN approach to ensure fair comparison - Update setup.py to include sparse solver compilation Both test_emd_sparse_vs_dense() and test_emd2_sparse_vs_dense() verify: * Identical costs between sparse and dense solvers * Marginal constraint satisfaction for both solvers
1 parent be211ac commit 04c12a0

File tree

7 files changed

+932
-63
lines changed

7 files changed

+932
-63
lines changed

ot/lp/EMD.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,24 @@ enum ProblemType {
3232
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
3333
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
3434

35+
int EMD_wrap_sparse(
36+
int n1,
37+
int n2,
38+
double *X,
39+
double *Y,
40+
uint64_t n_edges, // Number of edges in sparse graph
41+
int64_t *edge_sources, // Source indices for each edge (n_edges)
42+
int64_t *edge_targets, // Target indices for each edge (n_edges)
43+
double *edge_costs, // Cost for each edge (n_edges)
44+
int64_t *flow_sources_out, // Output: source indices of non-zero flows
45+
int64_t *flow_targets_out, // Output: target indices of non-zero flows
46+
double *flow_values_out, // Output: flow values
47+
uint64_t *n_flows_out,
48+
double *alpha, // Output: dual variables for sources (n1)
49+
double *beta, // Output: dual variables for targets (n2)
50+
double *cost, // Output: total transportation cost
51+
uint64_t maxIter // Maximum iterations for solver
52+
);
3553

3654

3755
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
#include "network_simplex_simple.h"
1717
#include "network_simplex_simple_omp.h"
18+
#include "sparse_bipartitegraph.h"
1819
#include "EMD.h"
1920
#include <cstdint>
21+
#include <unordered_map>
2022

2123

2224
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
216218

217219
return ret;
218220
}
221+
222+
// ============================================================================
223+
// SPARSE VERSION: Accepts edge list instead of dense cost matrix
224+
// ============================================================================
225+
int EMD_wrap_sparse(
226+
int n1,
227+
int n2,
228+
double *X,
229+
double *Y,
230+
uint64_t n_edges,
231+
int64_t *edge_sources,
232+
int64_t *edge_targets,
233+
double *edge_costs,
234+
int64_t *flow_sources_out,
235+
int64_t *flow_targets_out,
236+
double *flow_values_out,
237+
uint64_t *n_flows_out,
238+
double *alpha,
239+
double *beta,
240+
double *cost,
241+
uint64_t maxIter
242+
) {
243+
using namespace lemon;
244+
245+
uint64_t n = 0;
246+
for (int i = 0; i < n1; i++) {
247+
double val = *(X + i);
248+
if (val > 0) {
249+
n++;
250+
} else if (val < 0) {
251+
return INFEASIBLE;
252+
}
253+
}
254+
255+
uint64_t m = 0;
256+
for (int i = 0; i < n2; i++) {
257+
double val = *(Y + i);
258+
if (val > 0) {
259+
m++;
260+
} else if (val < 0) {
261+
return INFEASIBLE;
262+
}
263+
}
264+
265+
std::vector<uint64_t> indI(n); // indI[graph_idx] = original_source_idx
266+
std::vector<uint64_t> indJ(m); // indJ[graph_idx] = original_target_idx
267+
std::vector<double> weights1(n); // Source masses (positive only)
268+
std::vector<double> weights2(m); // Target masses (negative for demand)
269+
270+
// Create reverse mapping: original_idx → graph_idx
271+
std::vector<int64_t> source_to_graph(n1, -1);
272+
std::vector<int64_t> target_to_graph(n2, -1);
273+
274+
uint64_t cur = 0;
275+
for (int i = 0; i < n1; i++) {
276+
double val = *(X + i);
277+
if (val > 0) {
278+
weights1[cur] = val; // Store the mass
279+
indI[cur] = i; // Forward map: graph → original
280+
source_to_graph[i] = cur; // Reverse map: original → graph
281+
cur++;
282+
}
283+
}
284+
285+
cur = 0;
286+
for (int i = 0; i < n2; i++) {
287+
double val = *(Y + i);
288+
if (val > 0) {
289+
weights2[cur] = -val;
290+
indJ[cur] = i; // Forward map: graph → original
291+
target_to_graph[i] = cur; // Reverse map: original → graph
292+
cur++;
293+
}
294+
}
295+
296+
typedef SparseBipartiteDigraph Digraph;
297+
DIGRAPH_TYPEDEFS(Digraph);
298+
299+
Digraph di(n, m);
300+
301+
std::vector<std::pair<int, int>> edges; // (source, target) pairs
302+
std::vector<uint64_t> edge_to_arc; // edge_to_arc[k] = arc ID for edge k
303+
std::vector<double> arc_costs; // arc_costs[arc_id] = cost (for O(1) lookup)
304+
edges.reserve(n_edges);
305+
edge_to_arc.reserve(n_edges);
306+
307+
uint64_t valid_edge_count = 0;
308+
for (uint64_t k = 0; k < n_edges; k++) {
309+
int64_t src_orig = edge_sources[k];
310+
int64_t tgt_orig = edge_targets[k];
311+
int64_t src = source_to_graph[src_orig];
312+
int64_t tgt = target_to_graph[tgt_orig];
313+
314+
if (src >= 0 && tgt >= 0) {
315+
edges.emplace_back(src, tgt + n);
316+
edge_to_arc.push_back(valid_edge_count);
317+
arc_costs.push_back(edge_costs[k]); // Store cost indexed by arc ID
318+
valid_edge_count++;
319+
} else {
320+
edge_to_arc.push_back(UINT64_MAX);
321+
}
322+
}
323+
324+
325+
di.buildFromEdges(edges);
326+
327+
NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
328+
di, true, (int)(n + m), di.arcNum(), maxIter
329+
);
330+
331+
net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m);
332+
333+
for (uint64_t k = 0; k < n_edges; k++) {
334+
if (edge_to_arc[k] != UINT64_MAX) {
335+
net.setCost(edge_to_arc[k], edge_costs[k]);
336+
}
337+
}
338+
339+
int ret = net.run();
340+
341+
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
342+
*cost = 0;
343+
*n_flows_out = 0;
344+
345+
Arc a;
346+
di.first(a);
347+
for (; a != INVALID; di.next(a)) {
348+
uint64_t i = di.source(a);
349+
uint64_t j = di.target(a);
350+
double flow = net.flow(a);
351+
352+
uint64_t orig_i = indI[i];
353+
uint64_t orig_j = indJ[j - n];
354+
355+
356+
double arc_cost = arc_costs[a];
357+
358+
*cost += flow * arc_cost;
359+
360+
361+
*(alpha + orig_i) = -net.potential(i);
362+
*(beta + orig_j) = net.potential(j);
363+
364+
if (flow > 1e-15) {
365+
flow_sources_out[*n_flows_out] = orig_i;
366+
flow_targets_out[*n_flows_out] = orig_j;
367+
flow_values_out[*n_flows_out] = flow;
368+
(*n_flows_out)++;
369+
}
370+
}
371+
}
372+
return ret;
373+
}

0 commit comments

Comments
 (0)