Skip to content

Format project using Air #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ vignettes/loo2-non-factorizable_cache/*

^CRAN-SUBMISSION$
^release-prep\.R$
^[\.]?air\.toml$
^\.vscode$
29 changes: 29 additions & 0 deletions .github/workflows/format-suggest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Workflow derived from https://github.com/posit-dev/setup-air/tree/main/examples
on:
pull_request:

name: format-suggest.yaml

permissions: read-all

jobs:
format-suggest:
name: format-suggest
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- uses: actions/checkout@v4

- name: Install
uses: posit-dev/setup-air@v1

- name: Format
run: air format .

- name: Suggest
uses: reviewdog/action-suggester@v1
with:
level: error
fail_level: error
tool_name: air
68 changes: 45 additions & 23 deletions R/E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,14 @@ E_loo <- function(x, psis_object, ...) {
#' @rdname E_loo
#' @export
E_loo.default <-
function(x,
psis_object,
...,
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL) {
function(
x,
psis_object,
...,
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL
) {
stopifnot(
is.numeric(x),
is.psis(psis_object),
Expand Down Expand Up @@ -137,12 +139,14 @@ E_loo.default <-
#' @rdname E_loo
#' @export
E_loo.matrix <-
function(x,
psis_object,
...,
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL) {
function(
x,
psis_object,
...,
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL
) {
stopifnot(
is.numeric(x),
is.psis(psis_object),
Expand All @@ -162,9 +166,13 @@ E_loo.matrix <-
}
w <- weights(psis_object, log = FALSE)

out <- vapply(seq_len(ncol(x)), function(i) {
E_fun(x[, i], w[, i], probs = probs)
}, FUN.VALUE = fun_val)
out <- vapply(
seq_len(ncol(x)),
function(i) {
E_fun(x[, i], w[, i], probs = probs)
},
FUN.VALUE = fun_val
)

if (is.null(log_ratios)) {
# Use of smoothed ratios gives slightly optimistic
Expand All @@ -183,7 +191,6 @@ E_loo.matrix <-
}



#' Select the function to use based on user's 'type' argument
#'
#' @noRd
Expand Down Expand Up @@ -290,22 +297,37 @@ E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) {
.E_loo_khat_i <- function(x_i, log_ratios_i, tail_len_i) {
h_theta <- x_i
r_theta <- exp(log_ratios_i - max(log_ratios_i))
khat_r <- posterior::pareto_khat(r_theta, tail = "right", ndraws_tail = tail_len_i)
if (is.list(khat_r)) { # retain compatiblity with older posterior that returned a list
khat_r <- posterior::pareto_khat(
r_theta,
tail = "right",
ndraws_tail = tail_len_i
)
if (is.list(khat_r)) {
# retain compatiblity with older posterior that returned a list
khat_r <- khat_r$khat
}
if (is.null(x_i) || is_constant(x_i) || length(unique(x_i))==2 ||
anyNA(x_i) || any(is.infinite(x_i))) {
if (
is.null(x_i) ||
is_constant(x_i) ||
length(unique(x_i)) == 2 ||
anyNA(x_i) ||
any(is.infinite(x_i))
) {
khat_r
} else {
khat_hr <- posterior::pareto_khat(h_theta * r_theta, tail = "both", ndraws_tail = tail_len_i)
if (is.list(khat_hr)) { # retain compatiblity with older posterior that returned a list
khat_hr <- posterior::pareto_khat(
h_theta * r_theta,
tail = "both",
ndraws_tail = tail_len_i
)
if (is.list(khat_hr)) {
# retain compatiblity with older posterior that returned a list
khat_hr <- khat_hr$khat
}
if (is.na(khat_hr) && is.na(khat_r)) {
k <- NA
} else {
k <- max(khat_hr, khat_r, na.rm=TRUE)
k <- max(khat_hr, khat_r, na.rm = TRUE)
}
k
}
Expand Down
34 changes: 25 additions & 9 deletions R/compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ compare <- function(..., x = list()) {
dots <- list(...)
if (length(dots)) {
if (length(x)) {
stop("If 'x' is specified then '...' should not be specified.",
call. = FALSE)
stop(
"If 'x' is specified then '...' should not be specified.",
call. = FALSE
)
}
nms <- as.character(match.call(expand.dots = TRUE))[-1L]
} else {
Expand Down Expand Up @@ -97,16 +99,18 @@ compare <- function(..., x = list()) {

x <- sapply(dots, function(x) {
est <- x$estimates
setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) )
setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))))
})
colnames(x) <- nms
rnms <- rownames(x)
comp <- x
ord <- order(x[grep("^elpd", rnms), ], decreasing = TRUE)
comp <- t(comp)[ord, ]
patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))),
use.names = FALSE)
col_ord <- unlist(
sapply(patts, function(p) grep(p, colnames(comp))),
use.names = FALSE
)
comp <- comp[, col_ord]

# compute elpd_diff and se_elpd_diff relative to best model
Expand All @@ -122,13 +126,25 @@ compare <- function(..., x = list()) {
}



# internal ----------------------------------------------------------------
compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), check_dims = TRUE) {
compare_two_models <- function(
loo_a,
loo_b,
return = c("elpd_diff", "se"),
check_dims = TRUE
) {
if (check_dims) {
if (dim(loo_a$pointwise)[1] != dim(loo_b$pointwise)[1]) {
stop(paste("Models don't have the same number of data points.",
"\nFound N_1 =", dim(loo_a$pointwise)[1], "and N_2 =", dim(loo_b$pointwise)[1]), call. = FALSE)
stop(
paste(
"Models don't have the same number of data points.",
"\nFound N_1 =",
dim(loo_a$pointwise)[1],
"and N_2 =",
dim(loo_b$pointwise)[1]
),
call. = FALSE
)
}
}

Expand Down
114 changes: 68 additions & 46 deletions R/crps.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ crps.matrix <- function(x, x2, y, ..., permutations = 1) {
#' @rdname crps
#' @export
crps.numeric <- function(x, x2, y, ..., permutations = 1) {
stopifnot(length(x) == length(x2),
length(y) == 1)
stopifnot(length(x) == length(x2), length(y) == 1)
crps.matrix(as.matrix(x), as.matrix(x2), y, permutations)
}

Expand All @@ -106,23 +105,32 @@ crps.numeric <- function(x, x2, y, ..., permutations = 1) {
#' @param cores The number of cores to use for parallelization of `[psis()]`.
#' See [psis()] for details.
loo_crps.matrix <-
function(x,
x2,
y,
log_lik,
...,
permutations = 1,
r_eff = 1,
cores = getOption("mc.cores", 1)) {
validate_crps_input(x, x2, y, log_lik)
repeats <- replicate(permutations,
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
simplify = F)
EXX <- Reduce(`+`, repeats) / permutations
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
EXy <- E_loo(abs(sweep(x, 2, y)), psis_obj, log_ratios = -log_lik, ...)$value
crps_output(.crps_fun(EXX, EXy))
}
function(
x,
x2,
y,
log_lik,
...,
permutations = 1,
r_eff = 1,
cores = getOption("mc.cores", 1)
) {
validate_crps_input(x, x2, y, log_lik)
repeats <- replicate(
permutations,
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
simplify = F
)
EXX <- Reduce(`+`, repeats) / permutations
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
EXy <- E_loo(
abs(sweep(x, 2, y)),
psis_obj,
log_ratios = -log_lik,
...
)$value
crps_output(.crps_fun(EXX, EXy))
}


#' @rdname crps
Expand All @@ -138,8 +146,7 @@ scrps.matrix <- function(x, x2, y, ..., permutations = 1) {
#' @rdname crps
#' @export
scrps.numeric <- function(x, x2, y, ..., permutations = 1) {
stopifnot(length(x) == length(x2),
length(y) == 1)
stopifnot(length(x) == length(x2), length(y) == 1)
scrps.matrix(as.matrix(x), as.matrix(x2), y, permutations)
}

Expand All @@ -155,40 +162,54 @@ loo_scrps.matrix <-
...,
permutations = 1,
r_eff = 1,
cores = getOption("mc.cores", 1)) {
validate_crps_input(x, x2, y, log_lik)
repeats <- replicate(permutations,
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
simplify = F)
EXX <- Reduce(`+`, repeats) / permutations
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
EXy <- E_loo(abs(sweep(x, 2, y)), psis_obj, log_ratios = -log_lik, ...)$value
crps_output(.crps_fun(EXX, EXy, scale = TRUE))
}
cores = getOption("mc.cores", 1)
) {
validate_crps_input(x, x2, y, log_lik)
repeats <- replicate(
permutations,
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
simplify = F
)
EXX <- Reduce(`+`, repeats) / permutations
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
EXy <- E_loo(
abs(sweep(x, 2, y)),
psis_obj,
log_ratios = -log_lik,
...
)$value
crps_output(.crps_fun(EXX, EXy, scale = TRUE))
}

# ------------ Internals ----------------


EXX_compute <- function(x, x2) {
S <- nrow(x)
colMeans(abs(x - x2[sample(1:S),]))
colMeans(abs(x - x2[sample(1:S), ]))
}


EXX_loo_compute <- function(x, x2, log_lik, r_eff = 1, ...) {
S <- nrow(x)
shuffle <- sample (1:S)
x2 <- x2[shuffle,]
log_lik2 <- log_lik[shuffle,]
psis_obj_joint <- psis(-log_lik - log_lik2 , r_eff = r_eff)
E_loo(abs(x - x2), psis_obj_joint, log_ratios = -log_lik - log_lik2, ...)$value
shuffle <- sample(1:S)
x2 <- x2[shuffle, ]
log_lik2 <- log_lik[shuffle, ]
psis_obj_joint <- psis(-log_lik - log_lik2, r_eff = r_eff)
E_loo(
abs(x - x2),
psis_obj_joint,
log_ratios = -log_lik - log_lik2,
...
)$value
}


#' Function to compute crps and scrps
#' @noRd
.crps_fun <- function(EXX, EXy, scale = FALSE) {
if (scale) return(-EXy/EXX - 0.5 * log(EXX))
if (scale) {
return(-EXy / EXX - 0.5 * log(EXX))
}
0.5 * EXX - EXy
}

Expand All @@ -208,11 +229,12 @@ crps_output <- function(crps_pw) {
#' Check that predictive draws and observed data are of compatible shape
#' @noRd
validate_crps_input <- function(x, x2, y, log_lik = NULL) {
stopifnot(is.numeric(x),
is.numeric(x2),
is.numeric(y),
identical(dim(x), dim(x2)),
ncol(x) == length(y),
ifelse(is.null(log_lik), TRUE, identical(dim(log_lik), dim(x)))
)
stopifnot(
is.numeric(x),
is.numeric(x2),
is.numeric(y),
identical(dim(x), dim(x2)),
ncol(x) == length(y),
ifelse(is.null(log_lik), TRUE, identical(dim(log_lik), dim(x)))
)
}
Loading