Skip to content

Commit 450237f

Browse files
committed
make xgb functions backwards compatible with 2.0.0.0 version
1 parent eb2651c commit 450237f

File tree

1 file changed

+105
-29
lines changed

1 file changed

+105
-29
lines changed

R/boost_tree.R

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -340,31 +340,68 @@ xgb_train <- function(
340340

341341
others <- process_others(others, arg_list)
342342

343+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
344+
if (!is.null(num_class) && num_class > 2) {
345+
arg_list$num_class <- num_class
346+
}
347+
348+
if (!is.null(others$objective)) {
349+
arg_list$objective <- others$objective
350+
others$objective <- NULL
351+
}
352+
if (!is.null(others$eval_metric)) {
353+
arg_list$eval_metric <- others$eval_metric
354+
others$eval_metric <- NULL
355+
}
356+
if (!is.null(others$nthread)) {
357+
arg_list$nthread <- others$nthread
358+
others$nthread <- NULL
359+
}
360+
361+
if (is.null(arg_list$objective)) {
362+
if (is.numeric(y)) {
363+
arg_list$objective <- "reg:squarederror"
364+
} else {
365+
if (num_class == 2) {
366+
arg_list$objective <- "binary:logistic"
367+
} else {
368+
arg_list$objective <- "multi:softprob"
369+
}
370+
}
371+
}
372+
}
373+
343374
main_args <- c(
344375
list(
345376
data = quote(x$data),
346-
watchlist = quote(x$watchlist),
347377
params = arg_list,
348378
nrounds = nrounds,
349379
early_stopping_rounds = early_stop
350380
),
351381
others
352382
)
383+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
384+
main_args$evals <- quote(x$watchlist)
385+
} else {
386+
main_args$watchlist <- quote(x$watchlist)
387+
}
353388

354-
if (is.null(main_args$objective)) {
355-
if (is.numeric(y)) {
356-
main_args$objective <- "reg:squarederror"
357-
} else {
358-
if (num_class == 2) {
359-
main_args$objective <- "binary:logistic"
389+
if (utils::packageVersion("xgboost") < "2.0.0.0") {
390+
if (is.null(main_args$objective)) {
391+
if (is.numeric(y)) {
392+
main_args$objective <- "reg:squarederror"
360393
} else {
361-
main_args$objective <- "multi:softprob"
394+
if (num_class == 2) {
395+
main_args$objective <- "binary:logistic"
396+
} else {
397+
main_args$objective <- "multi:softprob"
398+
}
362399
}
363400
}
364-
}
365401

366-
if (!is.null(num_class) && num_class > 2) {
367-
main_args$num_class <- num_class
402+
if (!is.null(num_class) && num_class > 2) {
403+
main_args$num_class <- num_class
404+
}
368405
}
369406

370407
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
@@ -506,21 +543,52 @@ as_xgb_data <- function(
506543
watch_list <- list(validation = val_data)
507544

508545
info_list <- list(label = y[trn_index])
509-
if (!is.null(weights)) {
510-
info_list$weight <- weights[trn_index]
546+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
547+
if (!is.null(weights)) {
548+
dat <- xgboost::xgb.DMatrix(
549+
data = x[trn_index, , drop = FALSE],
550+
missing = NA,
551+
label = y[trn_index],
552+
weight = weights[trn_index]
553+
)
554+
} else {
555+
dat <- xgboost::xgb.DMatrix(
556+
data = x[trn_index, , drop = FALSE],
557+
missing = NA,
558+
label = y[trn_index]
559+
)
560+
}
561+
} else {
562+
if (!is.null(weights)) {
563+
info_list$weight <- weights[trn_index]
564+
}
565+
dat <- xgboost::xgb.DMatrix(
566+
data = x[trn_index, , drop = FALSE],
567+
missing = NA,
568+
info = info_list
569+
)
511570
}
512-
dat <- xgboost::xgb.DMatrix(
513-
data = x[trn_index, , drop = FALSE],
514-
missing = NA,
515-
info = info_list
516-
)
517571
} else {
518-
info_list <- list(label = y)
519-
if (!is.null(weights)) {
520-
info_list$weight <- weights
572+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
573+
if (!is.null(weights)) {
574+
dat <- xgboost::xgb.DMatrix(
575+
x,
576+
missing = NA,
577+
label = y,
578+
weight = weights
579+
)
580+
} else {
581+
dat <- xgboost::xgb.DMatrix(x, missing = NA, label = y)
582+
}
583+
watch_list <- list(training = dat)
584+
} else {
585+
info_list <- list(label = y)
586+
if (!is.null(weights)) {
587+
info_list$weight <- weights
588+
}
589+
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
590+
watch_list <- list(training = dat)
521591
}
522-
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
523-
watch_list <- list(training = dat)
524592
}
525593
} else {
526594
dat <- xgboost::setinfo(x, "label", y)
@@ -579,12 +647,20 @@ multi_predict._xgb.Booster <-
579647
}
580648

581649
xgb_by_tree <- function(tree, object, new_data, type, ...) {
582-
pred <- xgb_predict(
583-
object$fit,
584-
new_data = new_data,
585-
iterationrange = c(1, tree + 1),
586-
ntreelimit = NULL
587-
)
650+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
651+
pred <- xgb_predict(
652+
object$fit,
653+
new_data = new_data,
654+
iterationrange = c(1, tree + 1)
655+
)
656+
} else {
657+
pred <- xgb_predict(
658+
object$fit,
659+
new_data = new_data,
660+
iterationrange = c(1, tree + 1),
661+
ntreelimit = NULL
662+
)
663+
}
588664

589665
# switch based on prediction type
590666
if (object$spec$mode == "regression") {

0 commit comments

Comments
 (0)