Skip to content

Commit 022720b

Browse files
committed
Fix int64_t type compatibility for Linux, remove sparse and return matrix parameter from emd and fix linting issues
1 parent 0eee6f1 commit 022720b

File tree

3 files changed

+65
-61
lines changed

3 files changed

+65
-61
lines changed

ot/lp/_network_simplex.py

Lines changed: 38 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import warnings
1313

1414
import scipy.sparse as sp
15-
import time
1615
from ..utils import list_to_array, check_number_threads
1716
from ..backend import get_backend
1817
from .emd_wrap import emd_c, emd_c_sparse, check_result
@@ -174,8 +173,6 @@ def emd(
174173
center_dual=True,
175174
numThreads=1,
176175
check_marginals=True,
177-
sparse=False,
178-
return_matrix=False,
179176
):
180177
r"""Solves the Earth Movers distance problem and returns the OT matrix
181178
@@ -236,22 +233,26 @@ def emd(
236233
check_marginals: bool, optional (default=True)
237234
If True, checks that the marginals mass are equal. If False, skips the
238235
check.
239-
sparse: bool, optional (default=False)
240-
If True, uses the sparse solver that only stores edges with finite costs.
241-
When sparse=True, M should be a scipy.sparse matrix.
242-
return_matrix: bool, optional (default=True)
243-
If True, returns the transport matrix. If False and sparse=True, returns
244-
sparse flow representation in log.
236+
237+
.. note:: The solver automatically detects sparse format when M is provided as:
238+
- A scipy.sparse matrix (coo, csr, csc, etc.)
239+
- A tuple (row_indices, col_indices, costs) representing an edge list
240+
241+
For sparse inputs, the solver uses a memory-efficient algorithm and returns
242+
the flow in edge format (via log dict) instead of a full matrix.
245243
246244
247245
Returns
248246
-------
249-
gamma: array-like, shape (ns, nt)
250-
Optimal transportation matrix for the given
251-
parameters
247+
gamma: array-like, shape (ns, nt), or None
248+
Optimal transportation matrix for the given parameters.
249+
For sparse inputs, returns None (use log=True to get flow in edge format).
252250
log: dict, optional
253-
If input log is true, a dictionary containing the
254-
cost and dual variables and exit status
251+
If input log is True, a dictionary containing the cost, dual variables,
252+
and exit status. For sparse inputs with log=True, also contains:
253+
- 'flow_sources': source nodes of flow edges
254+
- 'flow_targets': target nodes of flow edges
255+
- 'flow_values': flow values on edges
255256
256257
257258
Examples
@@ -287,7 +288,10 @@ def emd(
287288
edge_costs = None
288289
n1, n2 = None, None
289290

290-
if sparse:
291+
# Auto-detect sparse format
292+
is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3)
293+
294+
if is_sparse:
291295
if sp.issparse(M):
292296
if not isinstance(M, sp.coo_matrix):
293297
M_coo = sp.coo_matrix(M)
@@ -312,10 +316,6 @@ def emd(
312316
edge_costs = np.asarray(M[2], dtype=np.float64)
313317
n1 = int(edge_sources.max() + 1)
314318
n2 = int(edge_targets.max() + 1)
315-
else:
316-
raise ValueError(
317-
"When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
318-
)
319319

320320
a, b = list_to_array(a, b)
321321
else:
@@ -343,7 +343,7 @@ def emd(
343343
else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
344344
)
345345

346-
if sparse:
346+
if is_sparse:
347347
a, b = nx.to_numpy(a, b)
348348
else:
349349
M, a, b = nx.to_numpy(M, a, b)
@@ -375,14 +375,11 @@ def emd(
375375
numThreads = check_number_threads(numThreads)
376376

377377
if edge_sources is not None:
378+
# Sparse solver - never build full matrix
378379
flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse(
379380
a, b, edge_sources, edge_targets, edge_costs, numItermax
380381
)
381-
if return_matrix:
382-
G = np.zeros((len(a), len(b)), dtype=np.float64)
383-
G[flow_sources, flow_targets] = flow_values
384-
else:
385-
G = None
382+
G = None
386383
else:
387384
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
388385

@@ -413,7 +410,8 @@ def emd(
413410
log_dict["warning"] = result_code_string
414411
log_dict["result_code"] = result_code
415412

416-
if edge_sources is not None and not return_matrix:
413+
if edge_sources is not None:
414+
# For sparse, include flow in edge format
417415
log_dict["flow_sources"] = flow_sources
418416
log_dict["flow_targets"] = flow_targets
419417
log_dict["flow_values"] = flow_values
@@ -427,7 +425,7 @@ def emd(
427425
return nx.from_numpy(G, type_as=type_as)
428426
else:
429427
raise ValueError(
430-
"Cannot return matrix when return_matrix=False and sparse=True without log=True"
428+
"For sparse inputs, log=True is required to get the flow in edge format"
431429
)
432430

433431

@@ -441,7 +439,6 @@ def emd2(
441439
center_dual=True,
442440
numThreads=1,
443441
check_marginals=True,
444-
sparse=False,
445442
return_matrix=False,
446443
):
447444
r"""Solves the Earth Movers distance problem and returns the loss
@@ -503,11 +500,12 @@ def emd2(
503500
check_marginals: bool, optional (default=True)
504501
If True, checks that the marginals mass are equal. If False, skips the
505502
check.
506-
sparse: bool, optional (default=False)
507-
If True, uses the sparse solver that only stores edges with finite costs.
508-
This is memory-efficient when M has many infinite or forbidden edges.
509-
When sparse=True, M should be a scipy.sparse matrix (coo, csr, or csc format)
510-
or a tuple (row_indices, col_indices, costs) representing the edge list.
503+
504+
.. note:: The solver automatically detects sparse format when M is provided as:
505+
- A scipy.sparse matrix (coo, csr, csc, etc.)
506+
- A tuple (row_indices, col_indices, costs) representing an edge list
507+
508+
For sparse inputs, the solver uses a memory-efficient algorithm.
511509
Edges not included are treated as having infinite cost (forbidden).
512510
513511
@@ -554,14 +552,15 @@ def emd2(
554552
edge_costs = None
555553
n1, n2 = None, None
556554

557-
if sparse:
555+
# Auto-detect sparse format
556+
is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3)
557+
558+
if is_sparse:
558559
if sp.issparse(M):
559-
t0 = time.perf_counter()
560560
if not isinstance(M, sp.coo_matrix):
561561
M_coo = sp.coo_matrix(M)
562562
else:
563563
M_coo = M
564-
t1 = time.perf_counter()
565564

566565
edge_sources = (
567566
M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64)
@@ -574,21 +573,13 @@ def emd2(
574573
if M_coo.data.dtype == np.float64
575574
else M_coo.data.astype(np.float64)
576575
)
577-
t2 = time.perf_counter()
578-
print(
579-
f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms"
580-
)
581576
n1, n2 = M_coo.shape
582577
elif isinstance(M, tuple) and len(M) == 3:
583578
edge_sources = np.asarray(M[0], dtype=np.int64)
584579
edge_targets = np.asarray(M[1], dtype=np.int64)
585580
edge_costs = np.asarray(M[2], dtype=np.float64)
586581
n1 = int(edge_sources.max() + 1)
587582
n2 = int(edge_targets.max() + 1)
588-
else:
589-
raise ValueError(
590-
"When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
591-
)
592583

593584
a, b = list_to_array(a, b)
594585
else:
@@ -618,14 +609,14 @@ def emd2(
618609
)
619610

620611
a0, b0 = a, b
621-
M0 = None if sparse else M
612+
M0 = None if is_sparse else M
622613

623-
if sparse:
614+
if is_sparse:
624615
edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as)
625616
else:
626617
edge_costs_original = None
627618

628-
if sparse:
619+
if is_sparse:
629620
a, b = nx.to_numpy(a, b)
630621
else:
631622
M, a, b = nx.to_numpy(M, a, b)

ot/lp/emd_wrap.pyx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ from ..utils import dist
1414

1515
cimport cython
1616
cimport libc.math as math
17-
from libc.stdint cimport uint64_t
17+
from libc.stdint cimport uint64_t, int64_t
1818

1919
import warnings
2020

2121

2222
cdef extern from "EMD.h":
2323
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
2424
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
25-
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, long long *edge_sources, long long *edge_targets, double *edge_costs, long long *flow_sources_out, long long *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
25+
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, int64_t *edge_sources, int64_t *edge_targets, double *edge_costs, int64_t *flow_sources_out, int64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
2626
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2727

2828

@@ -212,8 +212,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
212212
@cython.wraparound(False)
213213
def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
214214
np.ndarray[double, ndim=1, mode="c"] b,
215-
np.ndarray[long long, ndim=1, mode="c"] edge_sources,
216-
np.ndarray[long long, ndim=1, mode="c"] edge_targets,
215+
np.ndarray[int64_t, ndim=1, mode="c"] edge_sources,
216+
np.ndarray[int64_t, ndim=1, mode="c"] edge_targets,
217217
np.ndarray[double, ndim=1, mode="c"] edge_costs,
218218
uint64_t max_iter):
219219
"""
@@ -259,8 +259,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
259259
cdef double cost = 0
260260

261261
# Allocate output arrays (max size = n_edges)
262-
cdef np.ndarray[long long, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64)
263-
cdef np.ndarray[long long, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64)
262+
cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64)
263+
cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64)
264264
cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64)
265265
cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
266266
cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
@@ -270,8 +270,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
270270
n1, n2,
271271
<double*> a.data, <double*> b.data,
272272
n_edges,
273-
<long long*> edge_sources.data, <long long*> edge_targets.data, <double*> edge_costs.data,
274-
<long long*> flow_sources.data, <long long*> flow_targets.data, <double*> flow_values.data,
273+
<int64_t*> edge_sources.data, <int64_t*> edge_targets.data, <double*> edge_costs.data,
274+
<int64_t*> flow_sources.data, <int64_t*> flow_targets.data, <double*> flow_values.data,
275275
&n_flows_out,
276276
<double*> alpha.data, <double*> beta.data, &cost, max_iter
277277
)

test/test_ot.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -985,19 +985,32 @@ def test_emd_sparse_vs_dense():
985985
C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0]
986986

987987
G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True)
988-
G_sparse, log_sparse = ot.emd(
989-
a, b, C_augmented, log=True, sparse=True, return_matrix=True
990-
)
988+
G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True)
991989

992990
cost_dense = log_dense["cost"]
993991
cost_sparse = log_sparse["cost"]
994992

995993
np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)
996994

995+
# For dense, G_dense is returned; for sparse, reconstruct from flow edges
997996
np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7)
998997
np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7)
999-
np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7)
1000-
np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, atol=1e-7)
998+
999+
# Reconstruct sparse matrix from flow for marginal checks
1000+
if G_sparse is None:
1001+
G_sparse_reconstructed = np.zeros((n_source, n_target))
1002+
G_sparse_reconstructed[
1003+
log_sparse["flow_sources"], log_sparse["flow_targets"]
1004+
] = log_sparse["flow_values"]
1005+
np.testing.assert_allclose(
1006+
a, G_sparse_reconstructed.sum(1), rtol=1e-5, atol=1e-7
1007+
)
1008+
np.testing.assert_allclose(
1009+
b, G_sparse_reconstructed.sum(0), rtol=1e-5, atol=1e-7
1010+
)
1011+
else:
1012+
np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7)
1013+
np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, atol=1e-7)
10011014

10021015

10031016
def test_emd2_sparse_vs_dense():
@@ -1071,7 +1084,7 @@ def test_emd2_sparse_vs_dense():
10711084
C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0]
10721085

10731086
cost_dense = ot.emd2(a, b, C_augmented_dense)
1074-
cost_sparse = ot.emd2(a, b, C_augmented, sparse=True)
1087+
cost_sparse = ot.emd2(a, b, C_augmented)
10751088

10761089
np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)
10771090

0 commit comments

Comments
 (0)