Skip to content

Commit dd1b31f

Browse files
committed
dense computation of costs for sliced_plans with jax
1 parent aee78fc commit dd1b31f

File tree

1 file changed

+69
-25
lines changed

1 file changed

+69
-25
lines changed

ot/sliced.py

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -821,42 +821,86 @@ def sliced_plans(
821821
for k in range(n_proj)
822822
]
823823

824+
if not dense and str(nx) == "jax":
825+
warnings.warn("JAX does not support sparse matrices, converting to dense")
826+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
827+
824828
else: # we compute plans
825829
_, plan = wasserstein_1d(
826830
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
827831
)
828832

829-
if metric in ("minkowski", "euclidean", "cityblock"):
830-
costs = [
831-
nx.sum(
832-
(
833-
(nx.sum(nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1))
834-
** (1 / p)
833+
if str(nx) == "jax": # dense computation
834+
if not dense:
835+
warnings.warn(
836+
"JAX does not support sparse matrices, converting to dense"
837+
)
838+
839+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
840+
841+
if metric in ("minkowski", "euclidean", "cityblock"):
842+
costs = [
843+
nx.sum(
844+
(
845+
(
846+
nx.sum(
847+
nx.abs(X[:, None, :] - Y[None, :, :]) ** p, axis=-1
848+
)
849+
)
850+
** (1 / p)
851+
)
852+
* plan[k].data
853+
)
854+
for k in range(n_proj)
855+
]
856+
elif metric == "sqeuclidean":
857+
costs = [
858+
nx.sum(
859+
(nx.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1))
860+
* plan[k].data
835861
)
836-
* plan[k].data
862+
for k in range(n_proj)
863+
]
864+
else:
865+
raise ValueError(
866+
"Sliced plans work only with metrics "
867+
+ "from the following list: "
868+
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
837869
)
838-
for k in range(n_proj)
839-
]
840-
elif metric == "sqeuclidean":
841-
costs = [
842-
nx.sum(
843-
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
844-
* plan[k].data
870+
871+
else: # not jax, sparse computation
872+
if metric in ("minkowski", "euclidean", "cityblock"):
873+
costs = [
874+
nx.sum(
875+
(
876+
(
877+
nx.sum(
878+
nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1
879+
)
880+
)
881+
** (1 / p)
882+
)
883+
* plan[k].data
884+
)
885+
for k in range(n_proj)
886+
]
887+
elif metric == "sqeuclidean":
888+
costs = [
889+
nx.sum(
890+
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
891+
* plan[k].data
892+
)
893+
for k in range(n_proj)
894+
]
895+
else:
896+
raise ValueError(
897+
"Sliced plans work only with metrics "
898+
+ "from the following list: "
899+
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
845900
)
846-
for k in range(n_proj)
847-
]
848-
else:
849-
raise ValueError(
850-
"Sliced plans work only with metrics "
851-
+ "from the following list: "
852-
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
853-
)
854901

855902
if dense:
856903
plan = [nx.todense(plan[k]) for k in range(n_proj)]
857-
elif str(nx) == "jax":
858-
warnings.warn("JAX does not support sparse matrices, converting to dense")
859-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
860904

861905
if log:
862906
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}

0 commit comments

Comments
 (0)