diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 2d31bee3c..b0321a15c 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -138,6 +138,16 @@ predict_raw._glmnetfit <- predict_raw_glmnet unname(x[, 1]) } +organize_glmnet_pre_pred <- function(x, object) { + x <- x[, rownames(object$fit$beta), drop = FALSE] + if (is_sparse_matrix(x)) { + return(x) + } + + as.matrix(x) +} + + organize_glmnet_class <- function(x, object) { prob_to_class_2(x[, 1], object) } diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index c24f07c9e..cbbaf750d 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -250,7 +250,7 @@ set_pred( args = list( object = expr(object$fit), - newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])), + newx = expr(organize_glmnet_pre_pred(new_data, object)), type = "response", s = expr(object$spec$args$penalty) ) diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 30c84b1a9..b7d04e160 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -127,6 +127,14 @@ Error in `maybe_sparse_matrix()`: ! no sparse vectors detected +# we don't run as.matrix() on sparse matrix for glmnet pred #1210 + + Code + predict(lm_fit, hotel_data) + Condition + Error in `predict.elnet()`: + ! data is sparse + # fit() errors if sparse matrix has no colnames Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 73715b2d0..1f29eccd5 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -314,6 +314,34 @@ test_that("maybe_sparse_matrix() is used correctly", { ) }) +test_that("we don't run as.matrix() on sparse matrix for glmnet pred #1210", { + skip_if_not_installed("glmnet") + + local_mocked_bindings( + predict.elnet = function(object, newx, ...) { + if (is_sparse_matrix(newx)) { + stop("data is sparse") + } else { + stop("data isn't sparse (should not happen)") + } + }, + .package = "glmnet" + ) + + hotel_data <- sparse_hotel_rates() + + spec <- linear_reg(penalty = 0) %>% + set_mode("regression") %>% + set_engine("glmnet") + + lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + + expect_snapshot( + error = TRUE, + predict(lm_fit, hotel_data) + ) +}) + test_that("fit() errors if sparse matrix has no colnames", { hotel_data <- sparse_hotel_rates() colnames(hotel_data) <- NULL