Skip to content

Commit 1151a2c

Browse files
committed
testing update
1 parent 791601d commit 1151a2c

File tree

5 files changed

+113
-79
lines changed

5 files changed

+113
-79
lines changed

tests/testthat/_snaps/args_and_modes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
set_mode(rand_forest())
1313
Condition
1414
Error in `set_mode()`:
15-
! Available modes for model type rand_forest are: "unknown", "classification", "regression", and "censored regression".
15+
! Available modes for model type rand_forest are: "unknown", "classification", "regression", "censored regression", and "quantile regression".
1616

1717
---
1818

tests/testthat/_snaps/rand_forest.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
res <- translate(set_engine(rand_forest(mode = "classification"), NULL))
2222
Condition
2323
Error in `set_engine()`:
24-
! Missing engine. Possible mode/engine combinations are: classification {ranger, randomForest, spark} and regression {ranger, randomForest, spark}.
24+
! Missing engine. Possible mode/engine combinations are: classification {ranger, randomForest, spark, grf}, quantile regression {grf}, and regression {ranger, randomForest, spark, grf}.
2525

2626
---
2727

tests/testthat/_snaps/registration.md

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,12 @@
363363
show_model_info("rand_forest")
364364
Output
365365
Information for `rand_forest`
366-
modes: unknown, classification, regression, censored regression
366+
modes: unknown, classification, regression, censored regression, quantile regression
367367
368368
engines:
369-
classification: randomForest, ranger1, spark
370-
regression: randomForest, ranger1, spark
369+
classification: grf1, randomForest, ranger1, spark
370+
quantile regression: grf1
371+
regression: grf1, randomForest, ranger1, spark
371372
372373
1The model can use case weights.
373374
@@ -384,24 +385,34 @@
384385
mtry --> feature_subset_strategy
385386
trees --> num_trees
386387
min_n --> min_instances_per_node
388+
grf:
389+
mtry --> mtry
390+
trees --> num.trees
391+
min_n --> min.node.size
387392
388393
fit modules:
389-
engine mode
390-
ranger classification
391-
ranger regression
392-
randomForest classification
393-
randomForest regression
394-
spark classification
395-
spark regression
394+
engine mode
395+
ranger classification
396+
ranger regression
397+
randomForest classification
398+
randomForest regression
399+
spark classification
400+
spark regression
401+
grf classification
402+
grf regression
403+
grf quantile regression
396404
397405
prediction modules:
398-
mode engine methods
399-
classification randomForest class, prob, raw
400-
classification ranger class, conf_int, prob, raw
401-
classification spark class, prob
402-
regression randomForest numeric, raw
403-
regression ranger conf_int, numeric, raw
404-
regression spark numeric
406+
mode engine methods
407+
classification grf class, conf_int, prob
408+
classification randomForest class, prob, raw
409+
classification ranger class, conf_int, prob, raw
410+
classification spark class, prob
411+
quantile regression grf quantile
412+
regression grf conf_int, numeric
413+
regression randomForest numeric, raw
414+
regression ranger conf_int, numeric, raw
415+
regression spark numeric
405416
406417

407418
---

tests/testthat/helper-objects.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
ctrl <- control_parsnip(verbosity = 1, catch = FALSE)
2-
caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
3-
quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)
1+
ctrl <- control_parsnip(verbosity = 1, catch = FALSE)
2+
caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
3+
quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)
44

55
run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0
66

@@ -29,15 +29,15 @@ if (rlang::is_installed("modeldata")) {
2929

3030
# ------------------------------------------------------------------------------
3131

32-
hpc <- hpc_data[1:150, c(2:5, 8)]
32+
hpc <- modeldata::hpc_data[1:150, c(2:5, 8)]
3333
num_hpc_pred <- names(hpc)[1:4]
3434
class_tab <- table(hpc$class, dnn = NULL)
3535
hpc_bad <-
3636
hpc |>
3737
dplyr::mutate(big_num = Inf)
3838

3939
set.seed(352)
40-
mlp_dat <- hpc[order(runif(150)),]
40+
mlp_dat <- hpc[order(runif(150)), ]
4141

4242
tr_mlp_dat <- mlp_dat[1:140, ]
4343
te_mlp_dat <- mlp_dat[141:150, ]
@@ -46,7 +46,7 @@ if (rlang::is_installed("modeldata")) {
4646
mlp_hpc_pred_list <- names(hpc)[1:4]
4747
nnet_hpc_pred_list <- names(hpc)[1:4]
4848

49-
hpc_nnet_dat <- hpc_data[1:150, c(2:5, 8)]
49+
hpc_nnet_dat <- modeldata::hpc_data[1:150, c(2:5, 8)]
5050

5151
# ------------------------------------------------------------------------------
5252

@@ -56,7 +56,7 @@ if (rlang::is_installed("modeldata")) {
5656
fit(compounds ~ ., data = hpc)
5757

5858
lending_club <-
59-
lending_club |>
59+
modeldata::lending_club |>
6060
dplyr::slice(1:200) |>
6161
dplyr::mutate(big_num = Inf)
6262

@@ -73,7 +73,7 @@ if (rlang::is_installed("modeldata")) {
7373
dplyr::select(price, beds, baths, sqft, latitude, longitude)
7474

7575
sac_train <- Sacramento_small[-(1:5), ]
76-
sac_test <- Sacramento_small[ 1:5 , ]
76+
sac_test <- Sacramento_small[1:5, ]
7777

7878
# ------------------------------------------------------------------------------
7979
# For sparse tibble testing

tests/testthat/test-registration.R

Lines changed: 75 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,60 @@ test_that('adding a new model', {
33

44
mod_items <- get_model_env() |> rlang::env_names()
55
sponges <- grep("sponge", mod_items, value = TRUE)
6-
exp_obj <- c('sponge_modes', 'sponge_fit', 'sponge_args',
7-
'sponge_predict', 'sponge_pkgs', 'sponge')
6+
exp_obj <- c(
7+
'sponge_modes',
8+
'sponge_fit',
9+
'sponge_args',
10+
'sponge_predict',
11+
'sponge_pkgs',
12+
'sponge'
13+
)
814
expect_equal(sort(sponges), sort(exp_obj))
915

1016
expect_equal(
1117
get_from_env("sponge"),
1218
tibble(engine = character(0), mode = character(0))
1319
)
1420

15-
expect_equal(
16-
get_from_env("sponge_pkgs"),
17-
tibble(engine = character(0), pkg = list(), mode = character(0))
18-
)
19-
20-
expect_equal(
21-
get_from_env("sponge_modes"), "unknown"
22-
)
23-
24-
expect_equal(
25-
get_from_env("sponge_args"),
26-
dplyr::tibble(engine = character(0), parsnip = character(0),
27-
original = character(0), func = vector("list"),
28-
has_submodel = logical(0))
29-
)
30-
31-
expect_equal(
32-
get_from_env("sponge_fit"),
33-
tibble(engine = character(0), mode = character(0), value = vector("list"))
34-
)
35-
36-
expect_equal(
37-
get_from_env("sponge_predict"),
38-
tibble(engine = character(0), mode = character(0),
39-
type = character(0), value = vector("list"))
40-
)
41-
42-
expect_snapshot(error = TRUE, set_new_model())
43-
expect_snapshot(error = TRUE, set_new_model(2))
44-
expect_snapshot(error = TRUE, set_new_model(letters[1:2]))
21+
expect_equal(
22+
get_from_env("sponge_pkgs"),
23+
tibble(engine = character(0), pkg = list(), mode = character(0))
24+
)
25+
26+
expect_equal(
27+
get_from_env("sponge_modes"),
28+
"unknown"
29+
)
30+
31+
expect_equal(
32+
get_from_env("sponge_args"),
33+
dplyr::tibble(
34+
engine = character(0),
35+
parsnip = character(0),
36+
original = character(0),
37+
func = vector("list"),
38+
has_submodel = logical(0)
39+
)
40+
)
41+
42+
expect_equal(
43+
get_from_env("sponge_fit"),
44+
tibble(engine = character(0), mode = character(0), value = vector("list"))
45+
)
46+
47+
expect_equal(
48+
get_from_env("sponge_predict"),
49+
tibble(
50+
engine = character(0),
51+
mode = character(0),
52+
type = character(0),
53+
value = vector("list")
54+
)
55+
)
56+
57+
expect_snapshot(error = TRUE, set_new_model())
58+
expect_snapshot(error = TRUE, set_new_model(2))
59+
expect_snapshot(error = TRUE, set_new_model(letters[1:2]))
4560
})
4661

4762

@@ -58,7 +73,6 @@ test_that('adding a new mode', {
5873
expect_equal(get_from_env("sponge_modes"), c("unknown", "classification"))
5974

6075
expect_snapshot(error = TRUE, set_model_mode("sponge"))
61-
6276
})
6377

6478

@@ -75,7 +89,10 @@ test_that('adding a new engine', {
7589
expect_equal(get_from_env("sponge_modes"), c("unknown", "classification"))
7690

7791
expect_snapshot(error = TRUE, set_model_engine("sponge", eng = "gum"))
78-
expect_snapshot(error = TRUE, set_model_engine("sponge", mode = "classification"))
92+
expect_snapshot(
93+
error = TRUE,
94+
set_model_engine("sponge", mode = "classification")
95+
)
7996
expect_snapshot(
8097
error = TRUE,
8198
set_model_engine("sponge", mode = "regression", eng = "gum")
@@ -90,7 +107,10 @@ test_that('adding a new package', {
90107

91108
expect_snapshot(error = TRUE, set_dependency("sponge", "gum", letters[1:2]))
92109
expect_snapshot(error = TRUE, set_dependency("sponge", "gummies", "trident"))
93-
expect_snapshot(error = TRUE, set_dependency("sponge", "gum", "trident", mode = "regression"))
110+
expect_snapshot(
111+
error = TRUE,
112+
set_dependency("sponge", "gum", "trident", mode = "regression")
113+
)
94114

95115
expect_equal(
96116
get_from_env("sponge_pkgs"),
@@ -100,16 +120,20 @@ test_that('adding a new package', {
100120
set_dependency("sponge", "gum", "juicy-fruit", mode = "classification")
101121
expect_equal(
102122
get_from_env("sponge_pkgs"),
103-
tibble(engine = "gum",
104-
pkg = list(c("trident", "juicy-fruit")),
105-
mode = "classification")
123+
tibble(
124+
engine = "gum",
125+
pkg = list(c("trident", "juicy-fruit")),
126+
mode = "classification"
127+
)
106128
)
107129

108130
expect_equal(
109131
get_dependency("sponge"),
110-
tibble(engine = "gum",
111-
pkg = list(c("trident", "juicy-fruit")),
112-
mode = "classification")
132+
tibble(
133+
engine = "gum",
134+
pkg = list(c("trident", "juicy-fruit")),
135+
mode = "classification"
136+
)
113137
)
114138
})
115139

@@ -140,9 +164,13 @@ test_that('adding a new argument', {
140164

141165
expect_equal(
142166
get_from_env("sponge_args"),
143-
tibble(engine = "gum", parsnip = "modeling", original = "modelling",
144-
func = list(list(pkg = "foo", fun = "bar")),
145-
has_submodel = FALSE)
167+
tibble(
168+
engine = "gum",
169+
parsnip = "modeling",
170+
original = "modelling",
171+
func = list(list(pkg = "foo", fun = "bar")),
172+
has_submodel = FALSE
173+
)
146174
)
147175

148176
expect_snapshot(
@@ -252,7 +280,6 @@ test_that('adding a new argument', {
252280
})
253281

254282

255-
256283
# ------------------------------------------------------------------------------
257284

258285
test_that('adding a new fit', {
@@ -273,7 +300,7 @@ test_that('adding a new fit', {
273300

274301
fit_env_data <- get_from_env("sponge_fit")
275302
expect_equal(
276-
fit_env_data[ 1:2],
303+
fit_env_data[1:2],
277304
tibble(engine = "gum", mode = "classification")
278305
)
279306

@@ -405,7 +432,7 @@ test_that('adding a new predict method', {
405432

406433
pred_env_data <- get_from_env("sponge_predict")
407434
expect_equal(
408-
pred_env_data[ 1:3],
435+
pred_env_data[1:3],
409436
tibble(engine = "gum", mode = "classification", type = "class")
410437
)
411438

@@ -415,7 +442,7 @@ test_that('adding a new predict method', {
415442
)
416443

417444
expect_equal(
418-
get_pred_type("sponge", "class")[ 1:3],
445+
get_pred_type("sponge", "class")[1:3],
419446
tibble(engine = "gum", mode = "classification", type = "class")
420447
)
421448

@@ -446,7 +473,6 @@ test_that('adding a new predict method', {
446473
)
447474
)
448475

449-
450476
expect_snapshot(
451477
error = TRUE,
452478
set_pred(
@@ -520,16 +546,13 @@ test_that('adding a new predict method', {
520546
value = class_vals_2
521547
)
522548
)
523-
524549
})
525550

526551

527-
528552
test_that('showing model info', {
529553
expect_snapshot(show_model_info("rand_forest"))
530554

531555
# ensure that we don't mention case weight support when the
532556
# notation would be ambiguous (#1000)
533557
expect_snapshot(show_model_info("mlp"))
534558
})
535-

0 commit comments

Comments
 (0)