Skip to content

Commit c71e544

Browse files
update tests
1 parent 75d6c11 commit c71e544

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

test/test_solvers.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -764,13 +764,13 @@ def test_bary_free_support(nx, reg, reg_type, unbalanced, unbalanced_type, warms
764764
ns = rng.randint(10, 20, K) # number of samples within each distribution
765765
n = 5 # number of samples in the barycenter
766766

767-
X_s = [rng.randn(ns_i, 2) for ns_i in ns]
767+
X_list = [rng.randn(ns_i, 2) for ns_i in ns]
768768
# X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1))
769769

770-
a_s = [ot.utils.unif(X.shape[0]) for X in X_s]
770+
a_list = [ot.utils.unif(X.shape[0]) for X in X_list]
771771
b = ot.utils.unif(n)
772772

773-
w_s = ot.utils.unif(K)
773+
w = ot.utils.unif(K)
774774

775775
try:
776776
if reg_type == "tuple":
@@ -784,10 +784,10 @@ def df(G):
784784
reg_type = (f, df)
785785
# print('test reg_type:', reg_type[0](None), reg_type[1](None))
786786
# solve default None weights
787-
sol0 = ot.bary_sample(
788-
X_s,
787+
sol0 = ot.bary_free_support(
788+
X_list,
789789
n,
790-
w_s=None,
790+
w=None,
791791
metric="sqeuclidean",
792792
reg=reg,
793793
reg_type=reg_type,
@@ -802,12 +802,12 @@ def df(G):
802802

803803
# solve provided uniform weights
804804

805-
sol = ot.bary_sample(
806-
X_s,
805+
sol = ot.bary_free_support(
806+
X_list,
807807
n,
808-
a_s=a_s,
808+
a_list=a_list,
809809
b_init=b,
810-
w_s=w_s,
810+
w=w,
811811
metric="sqeuclidean",
812812
reg=reg,
813813
reg_type=reg_type,
@@ -823,9 +823,9 @@ def df(G):
823823
assert_allclose_bary_sol(sol0, sol)
824824

825825
# solve in backend
826-
X_sb = nx.from_numpy(*X_s)
827-
a_sb = nx.from_numpy(*a_s)
828-
w_sb, bb = nx.from_numpy(w_s, b)
826+
X_listb = nx.from_numpy(*X_list)
827+
a_listb = nx.from_numpy(*a_list)
828+
wb, bb = nx.from_numpy(w, b)
829829

830830
if reg_type == "tuple":
831831

@@ -848,12 +848,12 @@ def dfb(G):
848848
"""
849849
reg_type = (f, df)
850850

851-
solb = ot.bary_sample(
852-
X_sb,
851+
solb = ot.bary_free_support(
852+
X_listb,
853853
n,
854-
a_s=a_sb,
854+
a_listb=a_listb,
855855
b_init=bb,
856-
w_s=w_sb,
856+
w=wb,
857857
metric="sqeuclidean",
858858
reg=reg,
859859
reg_type=reg_type,

0 commit comments

Comments
 (0)