@@ -821,42 +821,86 @@ def sliced_plans(
821821 for k in range (n_proj )
822822 ]
823823
824+ if not dense and str (nx ) == "jax" :
825+ warnings .warn ("JAX does not support sparse matrices, converting to dense" )
826+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
827+
824828 else : # we compute plans
825829 _ , plan = wasserstein_1d (
826830 X_theta , Y_theta , a , b , p , require_sort = True , return_plan = True
827831 )
828832
829- if metric in ("minkowski" , "euclidean" , "cityblock" ):
830- costs = [
831- nx .sum (
832- (
833- (nx .sum (nx .abs (X [plan [k ].row ] - Y [plan [k ].col ]) ** p , axis = 1 ))
834- ** (1 / p )
833+ if str (nx ) == "jax" : # dense computation
834+ if not dense :
835+ warnings .warn (
836+ "JAX does not support sparse matrices, converting to dense"
837+ )
838+
839+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
840+
841+ if metric in ("minkowski" , "euclidean" , "cityblock" ):
842+ costs = [
843+ nx .sum (
844+ (
845+ (
846+ nx .sum (
847+ nx .abs (X [:, None , :] - Y [None , :, :]) ** p , axis = - 1
848+ )
849+ )
850+ ** (1 / p )
851+ )
852+ * plan [k ].data
853+ )
854+ for k in range (n_proj )
855+ ]
856+ elif metric == "sqeuclidean" :
857+ costs = [
858+ nx .sum (
859+ (nx .sum ((X [:, None , :] - Y [None , :, :]) ** 2 , axis = - 1 ))
860+ * plan [k ].data
835861 )
836- * plan [k ].data
862+ for k in range (n_proj )
863+ ]
864+ else :
865+ raise ValueError (
866+ "Sliced plans work only with metrics "
867+ + "from the following list: "
868+ + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
837869 )
838- for k in range (n_proj )
839- ]
840- elif metric == "sqeuclidean" :
841- costs = [
842- nx .sum (
843- (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
844- * plan [k ].data
870+
871+ else : # not jax, sparse computation
872+ if metric in ("minkowski" , "euclidean" , "cityblock" ):
873+ costs = [
874+ nx .sum (
875+ (
876+ (
877+ nx .sum (
878+ nx .abs (X [plan [k ].row ] - Y [plan [k ].col ]) ** p , axis = 1
879+ )
880+ )
881+ ** (1 / p )
882+ )
883+ * plan [k ].data
884+ )
885+ for k in range (n_proj )
886+ ]
887+ elif metric == "sqeuclidean" :
888+ costs = [
889+ nx .sum (
890+ (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
891+ * plan [k ].data
892+ )
893+ for k in range (n_proj )
894+ ]
895+ else :
896+ raise ValueError (
897+ "Sliced plans work only with metrics "
898+ + "from the following list: "
899+ + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
845900 )
846- for k in range (n_proj )
847- ]
848- else :
849- raise ValueError (
850- "Sliced plans work only with metrics "
851- + "from the following list: "
852- + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
853- )
854901
855902 if dense :
856903 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
857- elif str (nx ) == "jax" :
858- warnings .warn ("JAX does not support sparse matrices, converting to dense" )
859- plan = [nx .todense (plan [k ]) for k in range (n_proj )]
860904
861905 if log :
862906 log_dict = {"X_theta" : X_theta , "Y_theta" : Y_theta , "thetas" : thetas }
0 commit comments