@@ -340,31 +340,68 @@ xgb_train <- function(
340340
341341 others <- process_others(others , arg_list )
342342
343+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
344+ if (! is.null(num_class ) && num_class > 2 ) {
345+ arg_list $ num_class <- num_class
346+ }
347+
348+ if (! is.null(others $ objective )) {
349+ arg_list $ objective <- others $ objective
350+ others $ objective <- NULL
351+ }
352+ if (! is.null(others $ eval_metric )) {
353+ arg_list $ eval_metric <- others $ eval_metric
354+ others $ eval_metric <- NULL
355+ }
356+ if (! is.null(others $ nthread )) {
357+ arg_list $ nthread <- others $ nthread
358+ others $ nthread <- NULL
359+ }
360+
361+ if (is.null(arg_list $ objective )) {
362+ if (is.numeric(y )) {
363+ arg_list $ objective <- " reg:squarederror"
364+ } else {
365+ if (num_class == 2 ) {
366+ arg_list $ objective <- " binary:logistic"
367+ } else {
368+ arg_list $ objective <- " multi:softprob"
369+ }
370+ }
371+ }
372+ }
373+
343374 main_args <- c(
344375 list (
345376 data = quote(x $ data ),
346- watchlist = quote(x $ watchlist ),
347377 params = arg_list ,
348378 nrounds = nrounds ,
349379 early_stopping_rounds = early_stop
350380 ),
351381 others
352382 )
383+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
384+ main_args $ evals <- quote(x $ watchlist )
385+ } else {
386+ main_args $ watchlist <- quote(x $ watchlist )
387+ }
353388
354- if (is.null(main_args $ objective )) {
355- if (is.numeric(y )) {
356- main_args $ objective <- " reg:squarederror"
357- } else {
358- if (num_class == 2 ) {
359- main_args $ objective <- " binary:logistic"
389+ if (utils :: packageVersion(" xgboost" ) < " 2.0.0.0" ) {
390+ if (is.null(main_args $ objective )) {
391+ if (is.numeric(y )) {
392+ main_args $ objective <- " reg:squarederror"
360393 } else {
361- main_args $ objective <- " multi:softprob"
394+ if (num_class == 2 ) {
395+ main_args $ objective <- " binary:logistic"
396+ } else {
397+ main_args $ objective <- " multi:softprob"
398+ }
362399 }
363400 }
364- }
365401
366- if (! is.null(num_class ) && num_class > 2 ) {
367- main_args $ num_class <- num_class
402+ if (! is.null(num_class ) && num_class > 2 ) {
403+ main_args $ num_class <- num_class
404+ }
368405 }
369406
370407 call <- make_call(fun = " xgb.train" , ns = " xgboost" , main_args )
@@ -506,21 +543,52 @@ as_xgb_data <- function(
506543 watch_list <- list (validation = val_data )
507544
508545 info_list <- list (label = y [trn_index ])
509- if (! is.null(weights )) {
510- info_list $ weight <- weights [trn_index ]
546+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
547+ if (! is.null(weights )) {
548+ dat <- xgboost :: xgb.DMatrix(
549+ data = x [trn_index , , drop = FALSE ],
550+ missing = NA ,
551+ label = y [trn_index ],
552+ weight = weights [trn_index ]
553+ )
554+ } else {
555+ dat <- xgboost :: xgb.DMatrix(
556+ data = x [trn_index , , drop = FALSE ],
557+ missing = NA ,
558+ label = y [trn_index ]
559+ )
560+ }
561+ } else {
562+ if (! is.null(weights )) {
563+ info_list $ weight <- weights [trn_index ]
564+ }
565+ dat <- xgboost :: xgb.DMatrix(
566+ data = x [trn_index , , drop = FALSE ],
567+ missing = NA ,
568+ info = info_list
569+ )
511570 }
512- dat <- xgboost :: xgb.DMatrix(
513- data = x [trn_index , , drop = FALSE ],
514- missing = NA ,
515- info = info_list
516- )
517571 } else {
518- info_list <- list (label = y )
519- if (! is.null(weights )) {
520- info_list $ weight <- weights
572+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
573+ if (! is.null(weights )) {
574+ dat <- xgboost :: xgb.DMatrix(
575+ x ,
576+ missing = NA ,
577+ label = y ,
578+ weight = weights
579+ )
580+ } else {
581+ dat <- xgboost :: xgb.DMatrix(x , missing = NA , label = y )
582+ }
583+ watch_list <- list (training = dat )
584+ } else {
585+ info_list <- list (label = y )
586+ if (! is.null(weights )) {
587+ info_list $ weight <- weights
588+ }
589+ dat <- xgboost :: xgb.DMatrix(x , missing = NA , info = info_list )
590+ watch_list <- list (training = dat )
521591 }
522- dat <- xgboost :: xgb.DMatrix(x , missing = NA , info = info_list )
523- watch_list <- list (training = dat )
524592 }
525593 } else {
526594 dat <- xgboost :: setinfo(x , " label" , y )
@@ -579,12 +647,20 @@ multi_predict._xgb.Booster <-
579647 }
580648
581649xgb_by_tree <- function (tree , object , new_data , type , ... ) {
582- pred <- xgb_predict(
583- object $ fit ,
584- new_data = new_data ,
585- iterationrange = c(1 , tree + 1 ),
586- ntreelimit = NULL
587- )
650+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
651+ pred <- xgb_predict(
652+ object $ fit ,
653+ new_data = new_data ,
654+ iterationrange = c(1 , tree + 1 )
655+ )
656+ } else {
657+ pred <- xgb_predict(
658+ object $ fit ,
659+ new_data = new_data ,
660+ iterationrange = c(1 , tree + 1 ),
661+ ntreelimit = NULL
662+ )
663+ }
588664
589665 # switch based on prediction type
590666 if (object $ spec $ mode == " regression" ) {
0 commit comments