
Commit 0eee6f1

[WIP] Add sparse EMD solver with unit tests
This PR implements a sparse bipartite graph EMD solver for memory-efficient optimal transport when the cost matrix has many infinite or forbidden edges.

Changes:
- Implement sparse bipartite graph EMD solver in C++
- Add Python bindings for the sparse solver (emd_wrap.pyx, _network_simplex.py)
- Add unit tests to verify that the sparse and dense solvers produce identical results
- Tests use an augmented k-NN approach to ensure a fair comparison

Tests verify correctness:
* test_emd_sparse_vs_dense() - verifies identical costs and marginal constraints
* test_emd2_sparse_vs_dense() - verifies the cost-only version

Status: WIP - seeking feedback on the implementation approach
TODO: Add example script and documentation
1 parent 04c12a0 commit 0eee6f1
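
The example script is still listed as a TODO in the commit message. Based only on the entry points exercised by the new tests below, a minimal usage sketch might look like the following; the sparse=True / return_matrix=True keywords and the COO-or-(row, col, data) input form are taken from this WIP diff, so the final API may differ.

import numpy as np
from scipy.sparse import coo_matrix
import ot

# Two uniform marginals over 4 source and 4 target points.
a = np.ones(4) / 4
b = np.ones(4) / 4

# Sparse cost matrix: only the listed (row, col) edges are allowed.
rows = np.array([0, 0, 1, 1, 2, 2, 3, 3])
cols = np.array([0, 1, 1, 2, 2, 3, 3, 0])
costs = np.array([1.0, 2.0, 1.5, 2.0, 0.5, 2.5, 1.0, 3.0])
C = coo_matrix((costs, (rows, cols)), shape=(4, 4))

# Transport plan and log from the sparse solver; per the ValueError message
# in the diff, a (row, col, data) tuple would be accepted instead of the COO matrix.
G, log = ot.emd(a, b, C, log=True, sparse=True, return_matrix=True)
print(log["cost"])

# Cost-only variant.
loss = ot.emd2(a, b, C, sparse=True)
print(loss)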

File tree

2 files changed: +107 −34 lines changed


ot/lp/_network_simplex.py

Lines changed: 85 additions & 27 deletions
@@ -294,9 +294,17 @@ def emd(
             else:
                 M_coo = M

-            edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64)
-            edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64)
-            edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64)
+            edge_sources = (
+                M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64)
+            )
+            edge_targets = (
+                M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64)
+            )
+            edge_costs = (
+                M_coo.data
+                if M_coo.data.dtype == np.float64
+                else M_coo.data.astype(np.float64)
+            )
             n1, n2 = M_coo.shape
         elif isinstance(M, tuple) and len(M) == 3:
             edge_sources = np.asarray(M[0], dtype=np.int64)
@@ -305,7 +313,9 @@ def emd(
             n1 = int(edge_sources.max() + 1)
             n2 = int(edge_targets.max() + 1)
         else:
-            raise ValueError("When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)")
+            raise ValueError(
+                "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
+            )

         a, b = list_to_array(a, b)
     else:
@@ -321,9 +331,17 @@ def emd(
         type_as = a

     if len(a) == 0:
-        a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+        a = (
+            nx.ones((n1,), type_as=type_as) / n1
+            if n1
+            else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+        )
     if len(b) == 0:
-        b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+        b = (
+            nx.ones((n2,), type_as=type_as) / n2
+            if n2
+            else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+        )

     if sparse:
         a, b = nx.to_numpy(a, b)
@@ -334,7 +352,6 @@ def emd(
         a = np.asarray(a, dtype=np.float64)
         b = np.asarray(b, dtype=np.float64)

-
     if n1 is None:
         n1, n2 = M.shape

@@ -409,7 +426,9 @@ def emd(
         if G is not None:
             return nx.from_numpy(G, type_as=type_as)
         else:
-            raise ValueError("Cannot return matrix when return_matrix=False and sparse=True without log=True")
+            raise ValueError(
+                "Cannot return matrix when return_matrix=False and sparse=True without log=True"
+            )


 def emd2(
@@ -419,12 +438,11 @@ def emd2(
     processes=1,
     numItermax=100000,
     log=False,
-
     center_dual=True,
     numThreads=1,
     check_marginals=True,
     sparse=False,
-    return_matrix=False
+    return_matrix=False,
 ):
     r"""Solves the Earth Movers distance problem and returns the loss

@@ -534,7 +552,7 @@ def emd2(
     edge_sources = None
     edge_targets = None
     edge_costs = None
-    n1, n2 = None, None
+    n1, n2 = None, None

     if sparse:
         if sp.issparse(M):
@@ -545,11 +563,21 @@ def emd2(
                 M_coo = M
             t1 = time.perf_counter()

-            edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64)
-            edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64)
-            edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64)
+            edge_sources = (
+                M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64)
+            )
+            edge_targets = (
+                M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64)
+            )
+            edge_costs = (
+                M_coo.data
+                if M_coo.data.dtype == np.float64
+                else M_coo.data.astype(np.float64)
+            )
             t2 = time.perf_counter()
-            print(f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms")
+            print(
+                f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms"
+            )
             n1, n2 = M_coo.shape
         elif isinstance(M, tuple) and len(M) == 3:
             edge_sources = np.asarray(M[0], dtype=np.int64)
@@ -577,12 +605,20 @@ def emd2(

     # if empty array given then use uniform distributions
     if len(a) == 0:
-        a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+        a = (
+            nx.ones((n1,), type_as=type_as) / n1
+            if n1
+            else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+        )
     if len(b) == 0:
-        b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+        b = (
+            nx.ones((n2,), type_as=type_as) / n2
+            if n2
+            else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+        )

     a0, b0 = a, b
-    M0 = None if sparse else M
+    M0 = None if sparse else M

     if sparse:
         edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as)
@@ -625,15 +661,24 @@ def f(b):
         bsel = b != 0

         if edge_sources is not None:
-            flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse(
-                a, b, edge_sources, edge_targets, edge_costs, numItermax
+            flow_sources, flow_targets, flow_values, cost, u, v, result_code = (
+                emd_c_sparse(
+                    a, b, edge_sources, edge_targets, edge_costs, numItermax
+                )
             )

-            edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))}
+            edge_to_idx = {
+                (edge_sources[k], edge_targets[k]): k
+                for k in range(len(edge_sources))
+            }

             grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64)
             for idx in range(len(flow_sources)):
-                src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx]
+                src, tgt, flow = (
+                    flow_sources[idx],
+                    flow_targets[idx],
+                    flow_values[idx],
+                )
                 edge_idx = edge_to_idx.get((src, tgt), -1)
                 if edge_idx >= 0:
                     grad_edge_costs[edge_idx] = flow
@@ -679,7 +724,11 @@ def f(b):
             cost = nx.set_gradients(
                 nx.from_numpy(cost, type_as=type_as),
                 (a0, b0, edge_costs_original),
-                (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), nx.from_numpy(grad_edge_costs, type_as=type_as)),
+                (
+                    log["u"] - nx.mean(log["u"]),
+                    log["v"] - nx.mean(log["v"]),
+                    nx.from_numpy(grad_edge_costs, type_as=type_as),
+                ),
             )
         else:
             cost = nx.set_gradients(
@@ -694,14 +743,23 @@ def f(b):
         bsel = b != 0

         if edge_sources is not None:
-            flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse(
-                a, b, edge_sources, edge_targets, edge_costs, numItermax
+            flow_sources, flow_targets, flow_values, cost, u, v, result_code = (
+                emd_c_sparse(
+                    a, b, edge_sources, edge_targets, edge_costs, numItermax
+                )
             )

-            edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))}
+            edge_to_idx = {
+                (edge_sources[k], edge_targets[k]): k
+                for k in range(len(edge_sources))
+            }
             grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64)
             for idx in range(len(flow_sources)):
-                src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx]
+                src, tgt, flow = (
+                    flow_sources[idx],
+                    flow_targets[idx],
+                    flow_values[idx],
+                )
                 edge_idx = edge_to_idx.get((src, tgt), -1)
                 if edge_idx >= 0:
                     grad_edge_costs[edge_idx] = flow
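
The gradient bookkeeping in the emd2 hunks above (grad_edge_costs[edge_idx] = flow) relies on the standard sensitivity result that the derivative of the optimal cost with respect to an edge cost equals the optimal flow carried by that edge. A quick finite-difference sanity check of that identity, using only the existing dense solver and therefore independent of this PR, could look like:

import numpy as np
import ot

rng = np.random.RandomState(0)
a = np.ones(5) / 5
b = np.ones(5) / 5
M = rng.rand(5, 5)

G = ot.emd(a, b, M)        # optimal transport plan
base_cost = np.sum(G * M)  # optimal cost

# Perturb one edge cost and compare the finite difference to the flow on that edge.
i, j, eps = 2, 3, 1e-6
M_pert = M.copy()
M_pert[i, j] += eps
new_cost = ot.emd2(a, b, M_pert)

print((new_cost - base_cost) / eps, "should be close to", G[i, j])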

test/test_ot.py

Lines changed: 22 additions & 7 deletions
@@ -14,6 +14,7 @@
 from ot.backend import torch, tf, get_backend
 from scipy.sparse import coo_matrix

+
 def test_emd_dimension_and_mass_mismatch():
     # test emd and emd2 for dimension mismatch
     n_samples = 100
@@ -915,10 +916,14 @@ def test_dual_variables():


 def test_emd_sparse_vs_dense():
+    """Test that sparse and dense EMD solvers produce identical results.

+    Uses augmented k-NN graph approach: first solves with dense solver to
+    identify needed edges, then compares both solvers on the same graph.
+    """
     n_source = 100
     n_target = 100
-    k = 10
+    k = 10

     rng = np.random.RandomState(42)

@@ -971,17 +976,21 @@ def test_emd_sparse_vs_dense():
             cols_aug.append(j)
             data_aug.append(C[i, j])

-    C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target))
+    C_augmented = coo_matrix(
+        (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
+    )

     C_augmented_dense = np.full((n_source, n_target), large_cost)
     C_augmented_array = C_augmented.toarray()
     C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0]

     G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True)
-    G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True, sparse=True, return_matrix=True)
+    G_sparse, log_sparse = ot.emd(
+        a, b, C_augmented, log=True, sparse=True, return_matrix=True
+    )

-    cost_dense = log_dense['cost']
-    cost_sparse = log_sparse['cost']
+    cost_dense = log_dense["cost"]
+    cost_sparse = log_sparse["cost"]

     np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)

@@ -992,10 +1001,14 @@ def test_emd_sparse_vs_dense():


 def test_emd2_sparse_vs_dense():
+    """Test that sparse and dense emd2 solvers produce identical results.

+    Uses augmented k-NN graph approach: first solves with dense solver to
+    identify needed edges, then compares both solvers on the same graph.
+    """
     n_source = 100
     n_target = 100
-    k = 10
+    k = 10

     rng = np.random.RandomState(42)

@@ -1049,7 +1062,9 @@ def test_emd2_sparse_vs_dense():
             cols_aug.append(j)
             data_aug.append(C[i, j])

-    C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target))
+    C_augmented = coo_matrix(
+        (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
+    )

     C_augmented_dense = np.full((n_source, n_target), large_cost)
     C_augmented_array = C_augmented.toarray()
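
The test docstrings above describe an augmented k-NN construction: keep the k nearest-neighbour edges, then add whatever extra edges the dense optimal plan uses, so that the sparse and dense solvers are compared on the same feasible graph. A rough, self-contained sketch of that construction (names such as knn_edges and G_full are illustrative, not taken from the test file) could be:

import numpy as np
from scipy.sparse import coo_matrix
from scipy.spatial.distance import cdist
import ot

rng = np.random.RandomState(42)
n_source, n_target, k = 100, 100, 10
xs, xt = rng.randn(n_source, 2), rng.randn(n_target, 2)
a = np.ones(n_source) / n_source
b = np.ones(n_target) / n_target

C = cdist(xs, xt)
large_cost = C.max() * 10  # stand-in cost for the missing edges in the dense copy

# k nearest targets for every source point.
knn_edges = {(i, j) for i in range(n_source) for j in np.argsort(C[i])[:k]}

# Solve the dense problem once and keep every edge its optimal plan uses,
# so both solvers optimize over the same feasible graph.
G_full = ot.emd(a, b, C)
edges = sorted(knn_edges | {(i, j) for i, j in zip(*np.nonzero(G_full > 0))})

rows_aug = np.array([i for i, _ in edges])
cols_aug = np.array([j for _, j in edges])
data_aug = C[rows_aug, cols_aug]
C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target))

# Dense counterpart: missing edges get large_cost instead of being absent.
C_augmented_dense = np.full((n_source, n_target), large_cost)
C_augmented_dense[rows_aug, cols_aug] = data_aug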

0 commit comments
