Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: workflows
Title: Modeling Workflows
Version: 1.1.3.9000
Version: 1.1.3.9001
Authors@R: c(
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
person("Simon", "Couch", , "[email protected]", role = c("aut", "cre"),
Expand All @@ -24,7 +24,7 @@ Imports:
hardhat (>= 1.2.0),
lifecycle (>= 1.0.3),
modelenv (>= 0.1.0),
parsnip (>= 1.0.3),
parsnip (>= 1.1.0.9001),
rlang (>= 1.0.3),
tidyselect (>= 1.2.0),
vctrs (>= 0.4.1)
Expand All @@ -38,6 +38,8 @@ Suggests:
recipes (>= 1.0.0),
rmarkdown,
testthat (>= 3.0.0)
Remotes:
tidymodels/parsnip#955
VignetteBuilder:
knitr
Config/Needs/website:
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ S3method(print,workflow)
S3method(tidy,workflow)
S3method(tunable,workflow)
S3method(tune_args,workflow)
S3method(weight_propensity,workflow)
export(.fit_finalize)
export(.fit_model)
export(.fit_pre)
Expand Down Expand Up @@ -69,4 +70,5 @@ importFrom(hardhat,extract_recipe)
importFrom(hardhat,extract_spec_parsnip)
importFrom(lifecycle,deprecated)
importFrom(parsnip,fit_xy)
importFrom(parsnip,weight_propensity)
importFrom(stats,predict)
43 changes: 43 additions & 0 deletions R/weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#' Helper for bridging two-stage causal fits
#'
#' @inherit parsnip::weight_propensity.model_fit description
#'
#' @inheritParams parsnip::weight_propensity.model_fit
#'
#' @inherit parsnip::weight_propensity.model_fit return
#'
#' @inherit parsnip::weight_propensity.model_fit references
#'
#' @importFrom parsnip weight_propensity
#' @method weight_propensity workflow
#' @export
weight_propensity.workflow <- function(object,
wt_fn,
.treated = extract_fit_parsnip(object)$lvl[2],
...,
data) {
if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) {
abort("`wt_fn` must be a function.")
}

if (rlang::is_missing(data) || !is.data.frame(data)) {
abort("`data` must be the data supplied as the data argument to `fit()`.")
}

if (!is_trained_workflow(object)) {
abort("`weight_propensity()` is not well-defined for an unfitted workflow.")
}

outcome_name <- names(object$pre$mold$outcomes)

preds <- predict(object, data, type = "prob")
preds <- preds[[paste0(".pred_", .treated)]]

data$.wts <-
hardhat::importance_weights(
wt_fn(preds, data[[outcome_name]], .treated = .treated, ...)
)

data
}

53 changes: 53 additions & 0 deletions man/weight_propensity.workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions tests/testthat/_snaps/weight_propensity.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# errors informatively with bad input

Code
weight_propensity(wf, silly_wt_fn, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `weight_propensity()` is not well-defined for an unfitted workflow.

---

Code
weight_propensity(wf_fit, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(wf_fit, "boop", data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(wf_fit, function(...) {
-1L
}, data = two_class_dat)
Condition
Error in `hardhat::importance_weights()`:
! `x` can't contain negative weights.

---

Code
weight_propensity(wf_fit, silly_wt_fn)
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

---

Code
weight_propensity(wf_fit, silly_wt_fn, data = "boop")
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

63 changes: 63 additions & 0 deletions tests/testthat/test-weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
test_that("basic functionality", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

lr_fit <- fit(workflow(Class ~ A + B, logistic_reg()), two_class_dat)

lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat)
expect_s3_class(lr_res1, "tbl_df")
expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts")))
expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat))))
})

test_that("errors informatively with bad input", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

# untrained workflow
wf <- workflow(Class ~ A + B, logistic_reg())

expect_snapshot(
error = TRUE,
weight_propensity(wf, silly_wt_fn, data = two_class_dat)
)

# bad `wt_fn`
wf_fit <- fit(wf, two_class_dat)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, "boop", data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, function(...) {-1L}, data = two_class_dat)
)

# bad `data`
expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, silly_wt_fn)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, silly_wt_fn, data = "boop")
)
})