@@ -8,6 +8,24 @@ hpc_xgboost <-
88 boost_tree(trees = 2 , mode = " classification" ) | >
99 set_engine(" xgboost" )
1010
11+ extract_xgb_param <- function (x , param ) {
12+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
13+ res <- attr(extract_fit_engine(x ), " params" )[[param ]]
14+ } else {
15+ res <- extract_fit_engine(x )[[param ]]
16+ }
17+ res
18+ }
19+
20+ extract_xgb_evaluation_log <- function (x ) {
21+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
22+ res <- attr(extract_fit_engine(x ), " evaluation_log" )
23+ } else {
24+ res <- extract_fit_engine(x )[[" evaluation_log" ]]
25+ }
26+ res
27+ }
28+
1129# ------------------------------------------------------------------------------
1230
1331test_that(' xgboost execution, classification' , {
@@ -59,13 +77,21 @@ test_that('xgboost execution, classification', {
5977 )
6078 })
6179
62- expect_equal(res_f $ fit $ evaluation_log , res_xy $ fit $ evaluation_log )
63- expect_equal(res_f_wts $ fit $ evaluation_log , res_xy_wts $ fit $ evaluation_log )
80+ expect_equal(
81+ extract_xgb_evaluation_log(res_f ),
82+ extract_xgb_evaluation_log(res_xy )
83+ )
84+ expect_equal(
85+ extract_xgb_evaluation_log(res_f_wts ),
86+ extract_xgb_evaluation_log(res_xy_wts )
87+ )
6488 # Check to see if the case weights had an effect
6589 expect_true(
66- ! isTRUE(all.equal(res_f $ fit $ evaluation_log , res_f_wts $ fit $ evaluation_log ))
90+ ! isTRUE(all.equal(
91+ extract_xgb_evaluation_log(res_f ),
92+ extract_xgb_evaluation_log(res_f_wts )
93+ ))
6794 )
68-
6995 expect_true(has_multi_predict(res_xy ))
7096 expect_equal(multi_predict_args(res_xy ), " trees" )
7197
@@ -209,10 +235,7 @@ test_that('xgboost regression prediction', {
209235 )
210236 expect_equal(form_pred , predict(form_fit , new_data = mtcars [1 : 8 , - 1 ])$ .pred )
211237
212- expect_equal(
213- extract_fit_engine(form_fit )$ params $ objective ,
214- " reg:squarederror"
215- )
238+ expect_equal(extract_xgb_param(form_fit , " objective" ), " reg:squarederror" )
216239})
217240
218241
@@ -228,10 +251,7 @@ test_that('xgboost alternate objective', {
228251 set_mode(" regression" )
229252
230253 xgb_fit <- spec | > fit(mpg ~ . , data = mtcars )
231- expect_equal(
232- extract_fit_engine(xgb_fit )$ params $ objective ,
233- " reg:pseudohubererror"
234- )
254+ expect_equal(extract_xgb_param(xgb_fit , " objective" ), " reg:pseudohubererror" )
235255 expect_no_error(xgb_preds <- predict(xgb_fit , new_data = mtcars [1 , ]))
236256 expect_s3_class(xgb_preds , " data.frame" )
237257
@@ -333,7 +353,7 @@ test_that('validation sets', {
333353 )
334354
335355 expect_equal(
336- colnames(extract_fit_engine (reg_fit )$ evaluation_log )[2 ],
356+ colnames(extract_xgb_evaluation_log (reg_fit ))[2 ],
337357 " validation_rmse"
338358 )
339359
@@ -345,7 +365,7 @@ test_that('validation sets', {
345365 )
346366
347367 expect_equal(
348- colnames(extract_fit_engine (reg_fit )$ evaluation_log )[2 ],
368+ colnames(extract_xgb_evaluation_log (reg_fit ))[2 ],
349369 " validation_mae"
350370 )
351371
@@ -357,7 +377,7 @@ test_that('validation sets', {
357377 )
358378
359379 expect_equal(
360- colnames(extract_fit_engine (reg_fit )$ evaluation_log )[2 ],
380+ colnames(extract_xgb_evaluation_log (reg_fit ))[2 ],
361381 " training_mae"
362382 )
363383
@@ -387,12 +407,29 @@ test_that('early stopping', {
387407 fit(mpg ~ . , data = mtcars [- (1 : 4 ), ])
388408 )
389409
410+ extract_xgb_nitter <- function (x ) {
411+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
412+ res <- nrow(attr(extract_fit_engine(x ), " evaluation_log" ))
413+ } else {
414+ res <- extract_fit_engine(reg_fit )$ niter
415+ }
416+ res
417+ }
418+ extract_xgb_best_iteration <- function (x ) {
419+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
420+ res <- attr(extract_fit_engine(x ), " early_stop" )$ best_iteration
421+ } else {
422+ res <- extract_fit_engine(reg_fit )$ best_iteration
423+ }
424+ res
425+ }
426+
390427 expect_equal(
391- extract_fit_engine (reg_fit )$ niter -
392- extract_fit_engine (reg_fit )$ best_iteration ,
428+ extract_xgb_nitter (reg_fit ) -
429+ extract_xgb_best_iteration (reg_fit ),
393430 5
394431 )
395- expect_true(extract_fit_engine (reg_fit )$ niter < 200 )
432+ expect_true(extract_xgb_nitter (reg_fit ) < 200 )
396433
397434 expect_no_condition(
398435 reg_fit <-
@@ -535,16 +572,29 @@ test_that('xgboost data and sparse matrices', {
535572 from_mat $ fit $ handle <- NULL
536573 from_sparse $ fit $ handle <- NULL
537574
538- expect_equal(
539- extract_fit_engine(from_df ),
540- extract_fit_engine(from_mat ),
541- ignore_function_env = TRUE
542- )
543- expect_equal(
544- extract_fit_engine(from_df ),
545- extract_fit_engine(from_sparse ),
546- ignore_function_env = TRUE
547- )
575+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
576+ expect_equal(
577+ attributes(extract_fit_engine(from_df )),
578+ attributes(extract_fit_engine(from_mat )),
579+ ignore_function_env = TRUE
580+ )
581+ expect_equal(
582+ attributes(extract_fit_engine(from_df )),
583+ attributes(extract_fit_engine(from_sparse )),
584+ ignore_function_env = TRUE
585+ )
586+ } else {
587+ expect_equal(
588+ extract_fit_engine(from_df ),
589+ extract_fit_engine(from_mat ),
590+ ignore_function_env = TRUE
591+ )
592+ expect_equal(
593+ extract_fit_engine(from_df ),
594+ extract_fit_engine(from_sparse ),
595+ ignore_function_env = TRUE
596+ )
597+ }
548598
549599 # case weights added
550600 expect_no_condition(
@@ -591,14 +641,20 @@ test_that('argument checks for data dimensions', {
591641 xy_fit <- spec | >
592642 fit_xy(x = penguins_dummy , y = penguins $ species , control = ctrl )
593643 )
594- expect_equal(extract_fit_engine(f_fit )$ params $ colsample_bynode , 1 )
595644 expect_equal(
596- extract_fit_engine(f_fit )$ params $ min_child_weight ,
645+ extract_xgb_param(f_fit , " colsample_bynode" ),
646+ 1
647+ )
648+ expect_equal(
649+ extract_xgb_param(f_fit , " min_child_weight" ),
597650 nrow(penguins )
598651 )
599- expect_equal(extract_fit_engine(xy_fit )$ params $ colsample_bynode , 1 )
600652 expect_equal(
601- extract_fit_engine(xy_fit )$ params $ min_child_weight ,
653+ extract_xgb_param(xy_fit , " colsample_bynode" ),
654+ 1
655+ )
656+ expect_equal(
657+ extract_xgb_param(xy_fit , " min_child_weight" ),
602658 nrow(penguins )
603659 )
604660})
@@ -633,15 +689,27 @@ test_that("fit and prediction with `event_level`", {
633689 xgbmat_train_1 <- xgb.DMatrix(data = train_x , label = train_y_1 )
634690
635691 set.seed(24 )
636- fit_xgb_1 <- xgboost :: xgb.train(
637- data = xgbmat_train_1 ,
638- nrounds = 10 ,
639- watchlist = list (" training" = xgbmat_train_1 ),
640- objective = " binary:logistic" ,
641- eval_metric = " auc" ,
642- verbose = 0
643- )
644-
692+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
693+ fit_xgb_1 <- xgboost :: xgb.train(
694+ params = list (
695+ objective = " binary:logistic" ,
696+ eval_metric = " auc"
697+ ),
698+ data = xgbmat_train_1 ,
699+ nrounds = 10 ,
700+ evals = list (" training" = xgbmat_train_1 ),
701+ verbose = 0
702+ )
703+ } else {
704+ fit_xgb_1 <- xgboost :: xgb.train(
705+ data = xgbmat_train_1 ,
706+ nrounds = 10 ,
707+ watchlist = list (" training" = xgbmat_train_1 ),
708+ objective = " binary:logistic" ,
709+ eval_metric = " auc" ,
710+ verbose = 0
711+ )
712+ }
645713 expect_equal(
646714 extract_fit_engine(fit_p_1 )$ evaluation_log ,
647715 fit_xgb_1 $ evaluation_log
@@ -661,14 +729,27 @@ test_that("fit and prediction with `event_level`", {
661729 xgbmat_train_2 <- xgb.DMatrix(data = train_x , label = train_y_2 )
662730
663731 set.seed(24 )
664- fit_xgb_2 <- xgboost :: xgb.train(
665- data = xgbmat_train_2 ,
666- nrounds = 10 ,
667- watchlist = list (" training" = xgbmat_train_2 ),
668- objective = " binary:logistic" ,
669- eval_metric = " auc" ,
670- verbose = 0
671- )
732+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
733+ fit_xgb_2 <- xgboost :: xgb.train(
734+ params = list (
735+ eval_metric = " auc" ,
736+ objective = " binary:logistic"
737+ ),
738+ data = xgbmat_train_2 ,
739+ nrounds = 10 ,
740+ evals = list (" training" = xgbmat_train_2 ),
741+ verbose = 0
742+ )
743+ } else {
744+ fit_xgb_2 <- xgboost :: xgb.train(
745+ data = xgbmat_train_2 ,
746+ nrounds = 10 ,
747+ watchlist = list (" training" = xgbmat_train_2 ),
748+ objective = " binary:logistic" ,
749+ eval_metric = " auc" ,
750+ verbose = 0
751+ )
752+ }
672753
673754 expect_equal(
674755 extract_fit_engine(fit_p_2 )$ evaluation_log ,
@@ -691,9 +772,9 @@ test_that("count/proportion parameters", {
691772 set_engine(" xgboost" ) | >
692773 set_mode(" regression" ) | >
693774 fit(mpg ~ . , data = mtcars )
694- expect_equal(extract_fit_engine (fit1 ) $ params $ colsample_bytree , 1 )
775+ expect_equal(extract_xgb_param (fit1 , " colsample_bytree" ) , 1 )
695776 expect_equal(
696- extract_fit_engine (fit1 ) $ params $ colsample_bynode ,
777+ extract_xgb_param (fit1 , " colsample_bynode" ) ,
697778 7 / (ncol(mtcars ) - 1 )
698779 )
699780
@@ -703,11 +784,11 @@ test_that("count/proportion parameters", {
703784 set_mode(" regression" ) | >
704785 fit(mpg ~ . , data = mtcars )
705786 expect_equal(
706- extract_fit_engine (fit2 ) $ params $ colsample_bytree ,
787+ extract_xgb_param (fit2 , " colsample_bytree" ) ,
707788 4 / (ncol(mtcars ) - 1 )
708789 )
709790 expect_equal(
710- extract_fit_engine (fit2 ) $ params $ colsample_bynode ,
791+ extract_xgb_param (fit2 , " colsample_bynode" ) ,
711792 7 / (ncol(mtcars ) - 1 )
712793 )
713794
@@ -716,17 +797,18 @@ test_that("count/proportion parameters", {
716797 set_engine(" xgboost" ) | >
717798 set_mode(" regression" ) | >
718799 fit(mpg ~ . , data = mtcars )
719- expect_equal(extract_fit_engine (fit3 ) $ params $ colsample_bytree , 1 )
720- expect_equal(extract_fit_engine (fit3 ) $ params $ colsample_bynode , 1 )
800+ expect_equal(extract_xgb_param (fit3 , " colsample_bytree" ) , 1 )
801+ expect_equal(extract_xgb_param (fit3 , " colsample_bynode" ) , 1 )
721802
722803 fit4 <-
723804 boost_tree(mtry = .9 , trees = 4 ) | >
724805 set_engine(" xgboost" , colsample_bytree = .1 , counts = FALSE ) | >
725806 set_mode(" regression" ) | >
726807 fit(mpg ~ . , data = mtcars )
727- expect_equal(extract_fit_engine (fit4 ) $ params $ colsample_bytree , .1 )
728- expect_equal(extract_fit_engine (fit4 ) $ params $ colsample_bynode , .9 )
808+ expect_equal(extract_xgb_param (fit4 , " colsample_bytree" ) , .1 )
809+ expect_equal(extract_xgb_param (fit4 , " colsample_bynode" ) , .9 )
729810
811+ extract_xgb_param(fit4 , " colsample_bynode" )
730812 expect_snapshot(
731813 error = TRUE ,
732814 boost_tree(mtry = .9 , trees = 4 ) | >
@@ -758,7 +840,7 @@ test_that('interface to param arguments', {
758840 class = " xgboost_params_warning"
759841 )
760842
761- expect_equal(extract_fit_engine (fit_1 ) $ params $ eval_metric , " mae" )
843+ expect_equal(extract_xgb_param (fit_1 , " eval_metric" ) , " mae" )
762844
763845 # pass params as main argument (good)
764846 spec_2 <-
@@ -769,7 +851,7 @@ test_that('interface to param arguments', {
769851 fit_2 <- spec_2 | > fit(mpg ~ . , data = mtcars )
770852 )
771853
772- expect_equal(extract_fit_engine (fit_2 ) $ params $ eval_metric , " mae" )
854+ expect_equal(extract_xgb_param (fit_2 , " eval_metric" ) , " mae" )
773855
774856 # pass objective to params argument (bad)
775857 spec_3 <-
@@ -781,10 +863,7 @@ test_that('interface to param arguments', {
781863 class = " xgboost_params_warning"
782864 )
783865
784- expect_equal(
785- extract_fit_engine(fit_3 )$ params $ objective ,
786- " reg:pseudohubererror"
787- )
866+ expect_equal(extract_xgb_param(fit_3 , " objective" ), " reg:pseudohubererror" )
788867
789868 # pass objective as main argument (good)
790869 spec_4 <-
@@ -795,10 +874,7 @@ test_that('interface to param arguments', {
795874 fit_4 <- spec_4 | > fit(mpg ~ . , data = mtcars )
796875 )
797876
798- expect_equal(
799- extract_fit_engine(fit_4 )$ params $ objective ,
800- " reg:pseudohubererror"
801- )
877+ expect_equal(extract_xgb_param(fit_4 , " objective" ), " reg:pseudohubererror" )
802878
803879 # pass a guarded argument as a main argument (bad)
804880 spec_5 <-
@@ -810,7 +886,7 @@ test_that('interface to param arguments', {
810886 class = " xgboost_guarded_warning"
811887 )
812888
813- expect_null(extract_fit_engine (fit_5 ) $ params $ watchlist )
889+ expect_null(extract_xgb_param (fit_5 , " watchlist" ) )
814890
815891 # pass two guarded arguments as main arguments (bad)
816892 spec_6 <-
@@ -822,7 +898,7 @@ test_that('interface to param arguments', {
822898 class = " xgboost_guarded_warning"
823899 )
824900
825- expect_null(extract_fit_engine( fit_5 ) $ params $ watchlist )
901+ expect_null(extract_xgb_param( fit_6 , " watchlist" ) )
826902
827903 # pass a guarded argument as params argument (bad)
828904 spec_7 <-
@@ -834,5 +910,5 @@ test_that('interface to param arguments', {
834910 class = " xgboost_params_warning"
835911 )
836912
837- expect_equal(extract_fit_engine( fit_5 ) $ params $ gamma , 0 )
913+ expect_equal(extract_xgb_param( fit_7 , " gamma" ) , 0 )
838914})
0 commit comments