1212import warnings
1313
1414import scipy .sparse as sp
15- import time
1615from ..utils import list_to_array , check_number_threads
1716from ..backend import get_backend
1817from .emd_wrap import emd_c , emd_c_sparse , check_result
@@ -174,8 +173,6 @@ def emd(
174173 center_dual = True ,
175174 numThreads = 1 ,
176175 check_marginals = True ,
177- sparse = False ,
178- return_matrix = False ,
179176):
180177 r"""Solves the Earth Movers distance problem and returns the OT matrix
181178
@@ -236,22 +233,26 @@ def emd(
236233 check_marginals: bool, optional (default=True)
237234 If True, checks that the marginal masses are equal. If False, skips the
238235 check.
239- sparse: bool, optional (default=False)
240- If True, uses the sparse solver that only stores edges with finite costs.
241- When sparse=True, M should be a scipy.sparse matrix.
242- return_matrix: bool, optional (default=True)
243- If True, returns the transport matrix. If False and sparse=True, returns
244- sparse flow representation in log.
236+
237+ .. note:: The solver automatically detects sparse format when M is provided as:
238+ - A scipy.sparse matrix (coo, csr, csc, etc.)
239+ - A tuple (row_indices, col_indices, costs) representing an edge list
240+
241+ For sparse inputs, the solver uses a memory-efficient algorithm and returns
242+ the flow in edge format (via log dict) instead of a full matrix.
245243
246244
247245 Returns
248246 -------
249- gamma: array-like, shape (ns, nt)
250- Optimal transportation matrix for the given
251- parameters
247+ gamma: array-like, shape (ns, nt), or None
248+ Optimal transportation matrix for the given parameters.
248+ For sparse inputs, returns None; log=True is required to get the flow in edge format.
252250 log: dict, optional
253- If input log is true, a dictionary containing the
254- cost and dual variables and exit status
251+ If input log is True, a dictionary containing the cost, dual variables,
252+ and exit status. For sparse inputs with log=True, also contains:
253+ - 'flow_sources': source nodes of flow edges
254+ - 'flow_targets': target nodes of flow edges
255+ - 'flow_values': flow values on edges
255256
256257
257258 Examples
@@ -287,7 +288,10 @@ def emd(
287288 edge_costs = None
288289 n1 , n2 = None , None
289290
290- if sparse :
291+ # Auto-detect sparse format
292+ is_sparse = sp .issparse (M ) or (isinstance (M , tuple ) and len (M ) == 3 )
293+
294+ if is_sparse :
291295 if sp .issparse (M ):
292296 if not isinstance (M , sp .coo_matrix ):
293297 M_coo = sp .coo_matrix (M )
@@ -312,10 +316,6 @@ def emd(
312316 edge_costs = np .asarray (M [2 ], dtype = np .float64 )
313317 n1 = int (edge_sources .max () + 1 )
314318 n2 = int (edge_targets .max () + 1 )
315- else :
316- raise ValueError (
317- "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
318- )
319319
320320 a , b = list_to_array (a , b )
321321 else :
@@ -343,7 +343,7 @@ def emd(
343343 else nx .ones ((M .shape [1 ],), type_as = type_as ) / M .shape [1 ]
344344 )
345345
346- if sparse :
346+ if is_sparse :
347347 a , b = nx .to_numpy (a , b )
348348 else :
349349 M , a , b = nx .to_numpy (M , a , b )
@@ -375,14 +375,11 @@ def emd(
375375 numThreads = check_number_threads (numThreads )
376376
377377 if edge_sources is not None :
378+ # Sparse solver - never build full matrix
378379 flow_sources , flow_targets , flow_values , cost , u , v , result_code = emd_c_sparse (
379380 a , b , edge_sources , edge_targets , edge_costs , numItermax
380381 )
381- if return_matrix :
382- G = np .zeros ((len (a ), len (b )), dtype = np .float64 )
383- G [flow_sources , flow_targets ] = flow_values
384- else :
385- G = None
382+ G = None
386383 else :
387384 G , cost , u , v , result_code = emd_c (a , b , M , numItermax , numThreads )
388385
@@ -413,7 +410,8 @@ def emd(
413410 log_dict ["warning" ] = result_code_string
414411 log_dict ["result_code" ] = result_code
415412
416- if edge_sources is not None and not return_matrix :
413+ if edge_sources is not None :
414+ # For sparse, include flow in edge format
417415 log_dict ["flow_sources" ] = flow_sources
418416 log_dict ["flow_targets" ] = flow_targets
419417 log_dict ["flow_values" ] = flow_values
@@ -427,7 +425,7 @@ def emd(
427425 return nx .from_numpy (G , type_as = type_as )
428426 else :
429427 raise ValueError (
430- "Cannot return matrix when return_matrix=False and sparse=True without log=True "
428+ "For sparse inputs, log=True is required to get the flow in edge format "
431429 )
432430
433431
@@ -441,7 +439,6 @@ def emd2(
441439 center_dual = True ,
442440 numThreads = 1 ,
443441 check_marginals = True ,
444- sparse = False ,
445442 return_matrix = False ,
446443):
447444 r"""Solves the Earth Movers distance problem and returns the loss
@@ -503,11 +500,12 @@ def emd2(
503500 check_marginals: bool, optional (default=True)
504501 If True, checks that the marginal masses are equal. If False, skips the
505502 check.
506- sparse: bool, optional (default=False)
507- If True, uses the sparse solver that only stores edges with finite costs.
508- This is memory-efficient when M has many infinite or forbidden edges.
509- When sparse=True, M should be a scipy.sparse matrix (coo, csr, or csc format)
510- or a tuple (row_indices, col_indices, costs) representing the edge list.
503+
504+ .. note:: The solver automatically detects sparse format when M is provided as:
505+ - A scipy.sparse matrix (coo, csr, csc, etc.)
506+ - A tuple (row_indices, col_indices, costs) representing an edge list
507+
508+ For sparse inputs, the solver uses a memory-efficient algorithm.
511509 Edges not included are treated as having infinite cost (forbidden).
512510
513511
@@ -554,14 +552,15 @@ def emd2(
554552 edge_costs = None
555553 n1 , n2 = None , None
556554
557- if sparse :
555+ # Auto-detect sparse format
556+ is_sparse = sp .issparse (M ) or (isinstance (M , tuple ) and len (M ) == 3 )
557+
558+ if is_sparse :
558559 if sp .issparse (M ):
559- t0 = time .perf_counter ()
560560 if not isinstance (M , sp .coo_matrix ):
561561 M_coo = sp .coo_matrix (M )
562562 else :
563563 M_coo = M
564- t1 = time .perf_counter ()
565564
566565 edge_sources = (
567566 M_coo .row if M_coo .row .dtype == np .int64 else M_coo .row .astype (np .int64 )
@@ -574,21 +573,13 @@ def emd2(
574573 if M_coo .data .dtype == np .float64
575574 else M_coo .data .astype (np .float64 )
576575 )
577- t2 = time .perf_counter ()
578- print (
579- f"[PY SPARSE] COO conversion: { (t1 - t0 )* 1000 :.3f} ms, array copies: { (t2 - t1 )* 1000 :.3f} ms"
580- )
581576 n1 , n2 = M_coo .shape
582577 elif isinstance (M , tuple ) and len (M ) == 3 :
583578 edge_sources = np .asarray (M [0 ], dtype = np .int64 )
584579 edge_targets = np .asarray (M [1 ], dtype = np .int64 )
585580 edge_costs = np .asarray (M [2 ], dtype = np .float64 )
586581 n1 = int (edge_sources .max () + 1 )
587582 n2 = int (edge_targets .max () + 1 )
588- else :
589- raise ValueError (
590- "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
591- )
592583
593584 a , b = list_to_array (a , b )
594585 else :
@@ -618,14 +609,14 @@ def emd2(
618609 )
619610
620611 a0 , b0 = a , b
621- M0 = None if sparse else M
612+ M0 = None if is_sparse else M
622613
623- if sparse :
614+ if is_sparse :
624615 edge_costs_original = nx .from_numpy (edge_costs , type_as = type_as )
625616 else :
626617 edge_costs_original = None
627618
628- if sparse :
619+ if is_sparse :
629620 a , b = nx .to_numpy (a , b )
630621 else :
631622 M , a , b = nx .to_numpy (M , a , b )
0 commit comments