@@ -294,9 +294,17 @@ def emd(
294294 else :
295295 M_coo = M
296296
297- edge_sources = M_coo .row if M_coo .row .dtype == np .int64 else M_coo .row .astype (np .int64 )
298- edge_targets = M_coo .col if M_coo .col .dtype == np .int64 else M_coo .col .astype (np .int64 )
299- edge_costs = M_coo .data if M_coo .data .dtype == np .float64 else M_coo .data .astype (np .float64 )
297+ edge_sources = (
298+ M_coo .row if M_coo .row .dtype == np .int64 else M_coo .row .astype (np .int64 )
299+ )
300+ edge_targets = (
301+ M_coo .col if M_coo .col .dtype == np .int64 else M_coo .col .astype (np .int64 )
302+ )
303+ edge_costs = (
304+ M_coo .data
305+ if M_coo .data .dtype == np .float64
306+ else M_coo .data .astype (np .float64 )
307+ )
300308 n1 , n2 = M_coo .shape
301309 elif isinstance (M , tuple ) and len (M ) == 3 :
302310 edge_sources = np .asarray (M [0 ], dtype = np .int64 )
@@ -305,7 +313,9 @@ def emd(
305313 n1 = int (edge_sources .max () + 1 )
306314 n2 = int (edge_targets .max () + 1 )
307315 else :
308- raise ValueError ("When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)" )
316+ raise ValueError (
317+ "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)"
318+ )
309319
310320 a , b = list_to_array (a , b )
311321 else :
@@ -321,9 +331,17 @@ def emd(
321331 type_as = a
322332
323333 if len (a ) == 0 :
324- a = nx .ones ((n1 ,), type_as = type_as ) / n1 if n1 else nx .ones ((M .shape [0 ],), type_as = type_as ) / M .shape [0 ]
334+ a = (
335+ nx .ones ((n1 ,), type_as = type_as ) / n1
336+ if n1
337+ else nx .ones ((M .shape [0 ],), type_as = type_as ) / M .shape [0 ]
338+ )
325339 if len (b ) == 0 :
326- b = nx .ones ((n2 ,), type_as = type_as ) / n2 if n2 else nx .ones ((M .shape [1 ],), type_as = type_as ) / M .shape [1 ]
340+ b = (
341+ nx .ones ((n2 ,), type_as = type_as ) / n2
342+ if n2
343+ else nx .ones ((M .shape [1 ],), type_as = type_as ) / M .shape [1 ]
344+ )
327345
328346 if sparse :
329347 a , b = nx .to_numpy (a , b )
@@ -334,7 +352,6 @@ def emd(
334352 a = np .asarray (a , dtype = np .float64 )
335353 b = np .asarray (b , dtype = np .float64 )
336354
337-
338355 if n1 is None :
339356 n1 , n2 = M .shape
340357
@@ -409,7 +426,9 @@ def emd(
409426 if G is not None :
410427 return nx .from_numpy (G , type_as = type_as )
411428 else :
412- raise ValueError ("Cannot return matrix when return_matrix=False and sparse=True without log=True" )
429+ raise ValueError (
430+ "Cannot return matrix when return_matrix=False and sparse=True without log=True"
431+ )
413432
414433
415434def emd2 (
@@ -419,12 +438,11 @@ def emd2(
419438 processes = 1 ,
420439 numItermax = 100000 ,
421440 log = False ,
422-
423441 center_dual = True ,
424442 numThreads = 1 ,
425443 check_marginals = True ,
426444 sparse = False ,
427- return_matrix = False
445+ return_matrix = False ,
428446):
429447 r"""Solves the Earth Movers distance problem and returns the loss
430448
@@ -534,7 +552,7 @@ def emd2(
534552 edge_sources = None
535553 edge_targets = None
536554 edge_costs = None
537- n1 , n2 = None , None
555+ n1 , n2 = None , None
538556
539557 if sparse :
540558 if sp .issparse (M ):
@@ -545,11 +563,21 @@ def emd2(
545563 M_coo = M
546564 t1 = time .perf_counter ()
547565
548- edge_sources = M_coo .row if M_coo .row .dtype == np .int64 else M_coo .row .astype (np .int64 )
549- edge_targets = M_coo .col if M_coo .col .dtype == np .int64 else M_coo .col .astype (np .int64 )
550- edge_costs = M_coo .data if M_coo .data .dtype == np .float64 else M_coo .data .astype (np .float64 )
566+ edge_sources = (
567+ M_coo .row if M_coo .row .dtype == np .int64 else M_coo .row .astype (np .int64 )
568+ )
569+ edge_targets = (
570+ M_coo .col if M_coo .col .dtype == np .int64 else M_coo .col .astype (np .int64 )
571+ )
572+ edge_costs = (
573+ M_coo .data
574+ if M_coo .data .dtype == np .float64
575+ else M_coo .data .astype (np .float64 )
576+ )
551577 t2 = time .perf_counter ()
552- print (f"[PY SPARSE] COO conversion: { (t1 - t0 )* 1000 :.3f} ms, array copies: { (t2 - t1 )* 1000 :.3f} ms" )
578+ print (
579+ f"[PY SPARSE] COO conversion: { (t1 - t0 )* 1000 :.3f} ms, array copies: { (t2 - t1 )* 1000 :.3f} ms"
580+ )
553581 n1 , n2 = M_coo .shape
554582 elif isinstance (M , tuple ) and len (M ) == 3 :
555583 edge_sources = np .asarray (M [0 ], dtype = np .int64 )
@@ -577,12 +605,20 @@ def emd2(
577605
578606 # if empty array given then use uniform distributions
579607 if len (a ) == 0 :
580- a = nx .ones ((n1 ,), type_as = type_as ) / n1 if n1 else nx .ones ((M .shape [0 ],), type_as = type_as ) / M .shape [0 ]
608+ a = (
609+ nx .ones ((n1 ,), type_as = type_as ) / n1
610+ if n1
611+ else nx .ones ((M .shape [0 ],), type_as = type_as ) / M .shape [0 ]
612+ )
581613 if len (b ) == 0 :
582- b = nx .ones ((n2 ,), type_as = type_as ) / n2 if n2 else nx .ones ((M .shape [1 ],), type_as = type_as ) / M .shape [1 ]
614+ b = (
615+ nx .ones ((n2 ,), type_as = type_as ) / n2
616+ if n2
617+ else nx .ones ((M .shape [1 ],), type_as = type_as ) / M .shape [1 ]
618+ )
583619
584620 a0 , b0 = a , b
585- M0 = None if sparse else M
621+ M0 = None if sparse else M
586622
587623 if sparse :
588624 edge_costs_original = nx .from_numpy (edge_costs , type_as = type_as )
@@ -625,15 +661,24 @@ def f(b):
625661 bsel = b != 0
626662
627663 if edge_sources is not None :
628- flow_sources , flow_targets , flow_values , cost , u , v , result_code = emd_c_sparse (
629- a , b , edge_sources , edge_targets , edge_costs , numItermax
664+ flow_sources , flow_targets , flow_values , cost , u , v , result_code = (
665+ emd_c_sparse (
666+ a , b , edge_sources , edge_targets , edge_costs , numItermax
667+ )
630668 )
631669
632- edge_to_idx = {(edge_sources [k ], edge_targets [k ]): k for k in range (len (edge_sources ))}
670+ edge_to_idx = {
671+ (edge_sources [k ], edge_targets [k ]): k
672+ for k in range (len (edge_sources ))
673+ }
633674
634675 grad_edge_costs = np .zeros (len (edge_costs ), dtype = np .float64 )
635676 for idx in range (len (flow_sources )):
636- src , tgt , flow = flow_sources [idx ], flow_targets [idx ], flow_values [idx ]
677+ src , tgt , flow = (
678+ flow_sources [idx ],
679+ flow_targets [idx ],
680+ flow_values [idx ],
681+ )
637682 edge_idx = edge_to_idx .get ((src , tgt ), - 1 )
638683 if edge_idx >= 0 :
639684 grad_edge_costs [edge_idx ] = flow
@@ -679,7 +724,11 @@ def f(b):
679724 cost = nx .set_gradients (
680725 nx .from_numpy (cost , type_as = type_as ),
681726 (a0 , b0 , edge_costs_original ),
682- (log ["u" ] - nx .mean (log ["u" ]), log ["v" ] - nx .mean (log ["v" ]), nx .from_numpy (grad_edge_costs , type_as = type_as )),
727+ (
728+ log ["u" ] - nx .mean (log ["u" ]),
729+ log ["v" ] - nx .mean (log ["v" ]),
730+ nx .from_numpy (grad_edge_costs , type_as = type_as ),
731+ ),
683732 )
684733 else :
685734 cost = nx .set_gradients (
@@ -694,14 +743,23 @@ def f(b):
694743 bsel = b != 0
695744
696745 if edge_sources is not None :
697- flow_sources , flow_targets , flow_values , cost , u , v , result_code = emd_c_sparse (
698- a , b , edge_sources , edge_targets , edge_costs , numItermax
746+ flow_sources , flow_targets , flow_values , cost , u , v , result_code = (
747+ emd_c_sparse (
748+ a , b , edge_sources , edge_targets , edge_costs , numItermax
749+ )
699750 )
700751
701- edge_to_idx = {(edge_sources [k ], edge_targets [k ]): k for k in range (len (edge_sources ))}
752+ edge_to_idx = {
753+ (edge_sources [k ], edge_targets [k ]): k
754+ for k in range (len (edge_sources ))
755+ }
702756 grad_edge_costs = np .zeros (len (edge_costs ), dtype = np .float64 )
703757 for idx in range (len (flow_sources )):
704- src , tgt , flow = flow_sources [idx ], flow_targets [idx ], flow_values [idx ]
758+ src , tgt , flow = (
759+ flow_sources [idx ],
760+ flow_targets [idx ],
761+ flow_values [idx ],
762+ )
705763 edge_idx = edge_to_idx .get ((src , tgt ), - 1 )
706764 if edge_idx >= 0 :
707765 grad_edge_costs [edge_idx ] = flow
0 commit comments