Skip to content

Commit f7f7d94

Browse files
committed
use correct version for xgboost switching
1 parent 1f56fff commit f7f7d94

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

R/boost_tree.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ xgb_train <- function(
340340

341341
others <- process_others(others, arg_list)
342342

343-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
343+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
344344
if (!is.null(num_class) && num_class > 2) {
345345
arg_list$num_class <- num_class
346346
}
@@ -380,7 +380,7 @@ xgb_train <- function(
380380
),
381381
others
382382
)
383-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
383+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
384384
main_args$evals <- quote(x$watchlist)
385385
} else {
386386
main_args$watchlist <- quote(x$watchlist)
@@ -543,7 +543,7 @@ as_xgb_data <- function(
543543
watch_list <- list(validation = val_data)
544544

545545
info_list <- list(label = y[trn_index])
546-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
546+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
547547
if (!is.null(weights)) {
548548
dat <- xgboost::xgb.DMatrix(
549549
data = x[trn_index, , drop = FALSE],
@@ -569,7 +569,7 @@ as_xgb_data <- function(
569569
)
570570
}
571571
} else {
572-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
572+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
573573
if (!is.null(weights)) {
574574
dat <- xgboost::xgb.DMatrix(
575575
x,
@@ -647,7 +647,7 @@ multi_predict._xgb.Booster <-
647647
}
648648

649649
xgb_by_tree <- function(tree, object, new_data, type, ...) {
650-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
650+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
651651
pred <- xgb_predict(
652652
object$fit,
653653
new_data = new_data,

tests/testthat/test-boost_tree_xgboost.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ hpc_xgboost <-
99
set_engine("xgboost")
1010

1111
extract_xgb_param <- function(x, param) {
12-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
12+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
1313
res <- attr(extract_fit_engine(x), "params")[[param]]
1414
} else {
1515
res <- extract_fit_engine(x)$param[[param]]
@@ -18,7 +18,7 @@ extract_xgb_param <- function(x, param) {
1818
}
1919

2020
extract_xgb_evaluation_log <- function(x) {
21-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
21+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
2222
res <- attr(extract_fit_engine(x), "evaluation_log")
2323
} else {
2424
res <- extract_fit_engine(x)[["evaluation_log"]]
@@ -408,15 +408,15 @@ test_that('early stopping', {
408408
)
409409

410410
extract_xgb_nitter <- function(x) {
411-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
411+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
412412
res <- nrow(attr(extract_fit_engine(x), "evaluation_log"))
413413
} else {
414414
res <- extract_fit_engine(reg_fit)$niter
415415
}
416416
res
417417
}
418418
extract_xgb_best_iteration <- function(x) {
419-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
419+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
420420
res <- attr(extract_fit_engine(x), "early_stop")$best_iteration
421421
} else {
422422
res <- extract_fit_engine(reg_fit)$best_iteration
@@ -572,7 +572,7 @@ test_that('xgboost data and sparse matrices', {
572572
from_mat$fit$handle <- NULL
573573
from_sparse$fit$handle <- NULL
574574

575-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
575+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
576576
expect_equal(
577577
attributes(extract_fit_engine(from_df)),
578578
attributes(extract_fit_engine(from_mat)),
@@ -689,7 +689,7 @@ test_that("fit and prediction with `event_level`", {
689689
xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1)
690690

691691
set.seed(24)
692-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
692+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
693693
fit_xgb_1 <- xgboost::xgb.train(
694694
params = list(
695695
objective = "binary:logistic",
@@ -729,7 +729,7 @@ test_that("fit and prediction with `event_level`", {
729729
xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2)
730730

731731
set.seed(24)
732-
if (utils::packageVersion("xgboost") > "2.0.0.0") {
732+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
733733
fit_xgb_2 <- xgboost::xgb.train(
734734
params = list(
735735
eval_metric = "auc",

0 commit comments

Comments
 (0)