Skip to content

Commit eb2651c

Browse files
committed
make xgboost tests version robust
1 parent b6112b2 commit eb2651c

File tree

1 file changed

+146
-70
lines changed

1 file changed

+146
-70
lines changed

tests/testthat/test-boost_tree_xgboost.R

Lines changed: 146 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ hpc_xgboost <-
88
boost_tree(trees = 2, mode = "classification") |>
99
set_engine("xgboost")
1010

11+
extract_xgb_param <- function(x, param) {
12+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
13+
res <- attr(extract_fit_engine(x), "params")[[param]]
14+
} else {
15+
res <- extract_fit_engine(x)[[param]]
16+
}
17+
res
18+
}
19+
20+
extract_xgb_evaluation_log <- function(x) {
21+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
22+
res <- attr(extract_fit_engine(x), "evaluation_log")
23+
} else {
24+
res <- extract_fit_engine(x)[["evaluation_log"]]
25+
}
26+
res
27+
}
28+
1129
# ------------------------------------------------------------------------------
1230

1331
test_that('xgboost execution, classification', {
@@ -59,13 +77,21 @@ test_that('xgboost execution, classification', {
5977
)
6078
})
6179

62-
expect_equal(res_f$fit$evaluation_log, res_xy$fit$evaluation_log)
63-
expect_equal(res_f_wts$fit$evaluation_log, res_xy_wts$fit$evaluation_log)
80+
expect_equal(
81+
extract_xgb_evaluation_log(res_f),
82+
extract_xgb_evaluation_log(res_xy)
83+
)
84+
expect_equal(
85+
extract_xgb_evaluation_log(res_f_wts),
86+
extract_xgb_evaluation_log(res_xy_wts)
87+
)
6488
# Check to see if the case weights had an effect
6589
expect_true(
66-
!isTRUE(all.equal(res_f$fit$evaluation_log, res_f_wts$fit$evaluation_log))
90+
!isTRUE(all.equal(
91+
extract_xgb_evaluation_log(res_f),
92+
extract_xgb_evaluation_log(res_f_wts)
93+
))
6794
)
68-
6995
expect_true(has_multi_predict(res_xy))
7096
expect_equal(multi_predict_args(res_xy), "trees")
7197

@@ -209,10 +235,7 @@ test_that('xgboost regression prediction', {
209235
)
210236
expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred)
211237

212-
expect_equal(
213-
extract_fit_engine(form_fit)$params$objective,
214-
"reg:squarederror"
215-
)
238+
expect_equal(extract_xgb_param(form_fit, "objective"), "reg:squarederror")
216239
})
217240

218241

@@ -228,10 +251,7 @@ test_that('xgboost alternate objective', {
228251
set_mode("regression")
229252

230253
xgb_fit <- spec |> fit(mpg ~ ., data = mtcars)
231-
expect_equal(
232-
extract_fit_engine(xgb_fit)$params$objective,
233-
"reg:pseudohubererror"
234-
)
254+
expect_equal(extract_xgb_param(xgb_fit, "objective"), "reg:pseudohubererror")
235255
expect_no_error(xgb_preds <- predict(xgb_fit, new_data = mtcars[1, ]))
236256
expect_s3_class(xgb_preds, "data.frame")
237257

@@ -333,7 +353,7 @@ test_that('validation sets', {
333353
)
334354

335355
expect_equal(
336-
colnames(extract_fit_engine(reg_fit)$evaluation_log)[2],
356+
colnames(extract_xgb_evaluation_log(reg_fit))[2],
337357
"validation_rmse"
338358
)
339359

@@ -345,7 +365,7 @@ test_that('validation sets', {
345365
)
346366

347367
expect_equal(
348-
colnames(extract_fit_engine(reg_fit)$evaluation_log)[2],
368+
colnames(extract_xgb_evaluation_log(reg_fit))[2],
349369
"validation_mae"
350370
)
351371

@@ -357,7 +377,7 @@ test_that('validation sets', {
357377
)
358378

359379
expect_equal(
360-
colnames(extract_fit_engine(reg_fit)$evaluation_log)[2],
380+
colnames(extract_xgb_evaluation_log(reg_fit))[2],
361381
"training_mae"
362382
)
363383

@@ -387,12 +407,29 @@ test_that('early stopping', {
387407
fit(mpg ~ ., data = mtcars[-(1:4), ])
388408
)
389409

410+
extract_xgb_nitter <- function(x) {
411+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
412+
res <- nrow(attr(extract_fit_engine(x), "evaluation_log"))
413+
} else {
414+
res <- extract_fit_engine(reg_fit)$niter
415+
}
416+
res
417+
}
418+
extract_xgb_best_iteration <- function(x) {
419+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
420+
res <- attr(extract_fit_engine(x), "early_stop")$best_iteration
421+
} else {
422+
res <- extract_fit_engine(reg_fit)$best_iteration
423+
}
424+
res
425+
}
426+
390427
expect_equal(
391-
extract_fit_engine(reg_fit)$niter -
392-
extract_fit_engine(reg_fit)$best_iteration,
428+
extract_xgb_nitter(reg_fit) -
429+
extract_xgb_best_iteration(reg_fit),
393430
5
394431
)
395-
expect_true(extract_fit_engine(reg_fit)$niter < 200)
432+
expect_true(extract_xgb_nitter(reg_fit) < 200)
396433

397434
expect_no_condition(
398435
reg_fit <-
@@ -535,16 +572,29 @@ test_that('xgboost data and sparse matrices', {
535572
from_mat$fit$handle <- NULL
536573
from_sparse$fit$handle <- NULL
537574

538-
expect_equal(
539-
extract_fit_engine(from_df),
540-
extract_fit_engine(from_mat),
541-
ignore_function_env = TRUE
542-
)
543-
expect_equal(
544-
extract_fit_engine(from_df),
545-
extract_fit_engine(from_sparse),
546-
ignore_function_env = TRUE
547-
)
575+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
576+
expect_equal(
577+
attributes(extract_fit_engine(from_df)),
578+
attributes(extract_fit_engine(from_mat)),
579+
ignore_function_env = TRUE
580+
)
581+
expect_equal(
582+
attributes(extract_fit_engine(from_df)),
583+
attributes(extract_fit_engine(from_sparse)),
584+
ignore_function_env = TRUE
585+
)
586+
} else {
587+
expect_equal(
588+
extract_fit_engine(from_df),
589+
extract_fit_engine(from_mat),
590+
ignore_function_env = TRUE
591+
)
592+
expect_equal(
593+
extract_fit_engine(from_df),
594+
extract_fit_engine(from_sparse),
595+
ignore_function_env = TRUE
596+
)
597+
}
548598

549599
# case weights added
550600
expect_no_condition(
@@ -591,14 +641,20 @@ test_that('argument checks for data dimensions', {
591641
xy_fit <- spec |>
592642
fit_xy(x = penguins_dummy, y = penguins$species, control = ctrl)
593643
)
594-
expect_equal(extract_fit_engine(f_fit)$params$colsample_bynode, 1)
595644
expect_equal(
596-
extract_fit_engine(f_fit)$params$min_child_weight,
645+
extract_xgb_param(f_fit, "colsample_bynode"),
646+
1
647+
)
648+
expect_equal(
649+
extract_xgb_param(f_fit, "min_child_weight"),
597650
nrow(penguins)
598651
)
599-
expect_equal(extract_fit_engine(xy_fit)$params$colsample_bynode, 1)
600652
expect_equal(
601-
extract_fit_engine(xy_fit)$params$min_child_weight,
653+
extract_xgb_param(xy_fit, "colsample_bynode"),
654+
1
655+
)
656+
expect_equal(
657+
extract_xgb_param(xy_fit, "min_child_weight"),
602658
nrow(penguins)
603659
)
604660
})
@@ -633,15 +689,27 @@ test_that("fit and prediction with `event_level`", {
633689
xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1)
634690

635691
set.seed(24)
636-
fit_xgb_1 <- xgboost::xgb.train(
637-
data = xgbmat_train_1,
638-
nrounds = 10,
639-
watchlist = list("training" = xgbmat_train_1),
640-
objective = "binary:logistic",
641-
eval_metric = "auc",
642-
verbose = 0
643-
)
644-
692+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
693+
fit_xgb_1 <- xgboost::xgb.train(
694+
params = list(
695+
objective = "binary:logistic",
696+
eval_metric = "auc"
697+
),
698+
data = xgbmat_train_1,
699+
nrounds = 10,
700+
evals = list("training" = xgbmat_train_1),
701+
verbose = 0
702+
)
703+
} else {
704+
fit_xgb_1 <- xgboost::xgb.train(
705+
data = xgbmat_train_1,
706+
nrounds = 10,
707+
watchlist = list("training" = xgbmat_train_1),
708+
objective = "binary:logistic",
709+
eval_metric = "auc",
710+
verbose = 0
711+
)
712+
}
645713
expect_equal(
646714
extract_fit_engine(fit_p_1)$evaluation_log,
647715
fit_xgb_1$evaluation_log
@@ -661,14 +729,27 @@ test_that("fit and prediction with `event_level`", {
661729
xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2)
662730

663731
set.seed(24)
664-
fit_xgb_2 <- xgboost::xgb.train(
665-
data = xgbmat_train_2,
666-
nrounds = 10,
667-
watchlist = list("training" = xgbmat_train_2),
668-
objective = "binary:logistic",
669-
eval_metric = "auc",
670-
verbose = 0
671-
)
732+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
733+
fit_xgb_2 <- xgboost::xgb.train(
734+
params = list(
735+
eval_metric = "auc",
736+
objective = "binary:logistic"
737+
),
738+
data = xgbmat_train_2,
739+
nrounds = 10,
740+
evals = list("training" = xgbmat_train_2),
741+
verbose = 0
742+
)
743+
} else {
744+
fit_xgb_2 <- xgboost::xgb.train(
745+
data = xgbmat_train_2,
746+
nrounds = 10,
747+
watchlist = list("training" = xgbmat_train_2),
748+
objective = "binary:logistic",
749+
eval_metric = "auc",
750+
verbose = 0
751+
)
752+
}
672753

673754
expect_equal(
674755
extract_fit_engine(fit_p_2)$evaluation_log,
@@ -691,9 +772,9 @@ test_that("count/proportion parameters", {
691772
set_engine("xgboost") |>
692773
set_mode("regression") |>
693774
fit(mpg ~ ., data = mtcars)
694-
expect_equal(extract_fit_engine(fit1)$params$colsample_bytree, 1)
775+
expect_equal(extract_xgb_param(fit1, "colsample_bytree"), 1)
695776
expect_equal(
696-
extract_fit_engine(fit1)$params$colsample_bynode,
777+
extract_xgb_param(fit1, "colsample_bynode"),
697778
7 / (ncol(mtcars) - 1)
698779
)
699780

@@ -703,11 +784,11 @@ test_that("count/proportion parameters", {
703784
set_mode("regression") |>
704785
fit(mpg ~ ., data = mtcars)
705786
expect_equal(
706-
extract_fit_engine(fit2)$params$colsample_bytree,
787+
extract_xgb_param(fit2, "colsample_bytree"),
707788
4 / (ncol(mtcars) - 1)
708789
)
709790
expect_equal(
710-
extract_fit_engine(fit2)$params$colsample_bynode,
791+
extract_xgb_param(fit2, "colsample_bynode"),
711792
7 / (ncol(mtcars) - 1)
712793
)
713794

@@ -716,17 +797,18 @@ test_that("count/proportion parameters", {
716797
set_engine("xgboost") |>
717798
set_mode("regression") |>
718799
fit(mpg ~ ., data = mtcars)
719-
expect_equal(extract_fit_engine(fit3)$params$colsample_bytree, 1)
720-
expect_equal(extract_fit_engine(fit3)$params$colsample_bynode, 1)
800+
expect_equal(extract_xgb_param(fit3, "colsample_bytree"), 1)
801+
expect_equal(extract_xgb_param(fit3, "colsample_bynode"), 1)
721802

722803
fit4 <-
723804
boost_tree(mtry = .9, trees = 4) |>
724805
set_engine("xgboost", colsample_bytree = .1, counts = FALSE) |>
725806
set_mode("regression") |>
726807
fit(mpg ~ ., data = mtcars)
727-
expect_equal(extract_fit_engine(fit4)$params$colsample_bytree, .1)
728-
expect_equal(extract_fit_engine(fit4)$params$colsample_bynode, .9)
808+
expect_equal(extract_xgb_param(fit4, "colsample_bytree"), .1)
809+
expect_equal(extract_xgb_param(fit4, "colsample_bynode"), .9)
729810

811+
extract_xgb_param(fit4, "colsample_bynode")
730812
expect_snapshot(
731813
error = TRUE,
732814
boost_tree(mtry = .9, trees = 4) |>
@@ -758,7 +840,7 @@ test_that('interface to param arguments', {
758840
class = "xgboost_params_warning"
759841
)
760842

761-
expect_equal(extract_fit_engine(fit_1)$params$eval_metric, "mae")
843+
expect_equal(extract_xgb_param(fit_1, "eval_metric"), "mae")
762844

763845
# pass params as main argument (good)
764846
spec_2 <-
@@ -769,7 +851,7 @@ test_that('interface to param arguments', {
769851
fit_2 <- spec_2 |> fit(mpg ~ ., data = mtcars)
770852
)
771853

772-
expect_equal(extract_fit_engine(fit_2)$params$eval_metric, "mae")
854+
expect_equal(extract_xgb_param(fit_2, "eval_metric"), "mae")
773855

774856
# pass objective to params argument (bad)
775857
spec_3 <-
@@ -781,10 +863,7 @@ test_that('interface to param arguments', {
781863
class = "xgboost_params_warning"
782864
)
783865

784-
expect_equal(
785-
extract_fit_engine(fit_3)$params$objective,
786-
"reg:pseudohubererror"
787-
)
866+
expect_equal(extract_xgb_param(fit_3, "objective"), "reg:pseudohubererror")
788867

789868
# pass objective as main argument (good)
790869
spec_4 <-
@@ -795,10 +874,7 @@ test_that('interface to param arguments', {
795874
fit_4 <- spec_4 |> fit(mpg ~ ., data = mtcars)
796875
)
797876

798-
expect_equal(
799-
extract_fit_engine(fit_4)$params$objective,
800-
"reg:pseudohubererror"
801-
)
877+
expect_equal(extract_xgb_param(fit_4, "objective"), "reg:pseudohubererror")
802878

803879
# pass a guarded argument as a main argument (bad)
804880
spec_5 <-
@@ -810,7 +886,7 @@ test_that('interface to param arguments', {
810886
class = "xgboost_guarded_warning"
811887
)
812888

813-
expect_null(extract_fit_engine(fit_5)$params$watchlist)
889+
expect_null(extract_xgb_param(fit_5, "watchlist"))
814890

815891
# pass two guarded arguments as main arguments (bad)
816892
spec_6 <-
@@ -822,7 +898,7 @@ test_that('interface to param arguments', {
822898
class = "xgboost_guarded_warning"
823899
)
824900

825-
expect_null(extract_fit_engine(fit_5)$params$watchlist)
901+
expect_null(extract_xgb_param(fit_6, "watchlist"))
826902

827903
# pass a guarded argument as params argument (bad)
828904
spec_7 <-
@@ -834,5 +910,5 @@ test_that('interface to param arguments', {
834910
class = "xgboost_params_warning"
835911
)
836912

837-
expect_equal(extract_fit_engine(fit_5)$params$gamma, 0)
913+
expect_equal(extract_xgb_param(fit_7, "gamma"), 0)
838914
})

0 commit comments

Comments
 (0)