Skip to content

Commit 6d27ff9

Browse files
initial resample function
1 parent 1fd8bd6 commit 6d27ff9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/test_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,20 +334,20 @@ def test_check_logger_exists():
334334

335335

336336
def test_class_stratify_check():
337-
selection_frac = 0.9
337+
train_frac = 0.9
338338
idx = np.arange(100)
339339
y = np.tile(np.arange(5), 20)
340-
train, test = resample(idx, selection_frac=selection_frac, random_state=0, stratify=y)
340+
train, test = resample(idx, train_frac=train_frac, random_state=0, stratify=y)
341341

342-
if int(np.ceil(len(idx) * selection_frac)) != len(train):
342+
if int(np.ceil(len(idx) * train_frac)) != len(train):
343343
raise ValueError("Incorrect train size")
344-
if (len(idx) - int(np.ceil(len(idx) * selection_frac))) != len(test):
344+
if (len(idx) - int(np.ceil(len(idx) * train_frac))) != len(test):
345345
raise ValueError("Incorrect test size")
346346

347347
classes, dist = np.unique(y, return_counts=True)
348348

349349
for cl, di in zip(classes, dist):
350-
if int(np.ceil(di * selection_frac)) != sum(y[train] == cl):
350+
if int(np.ceil(di * train_frac)) != sum(y[train] == cl):
351351
raise ValueError(f"Incorrect train class size {cl}")
352-
if di - int(np.ceil(di * selection_frac)) != sum(y[test] == cl):
352+
if di - int(np.ceil(di * train_frac)) != sum(y[test] == cl):
353353
raise ValueError(f"Incorrect test class size {cl}")

0 commit comments

Comments
 (0)