Skip to content

Commit

Permalink
Merge pull request #121 from TeemuSailynoja/loo_pit-for-discrete-data
Browse files Browse the repository at this point in the history
Randomisation for loo_pit of discrete data.
  • Loading branch information
jgabry authored Feb 22, 2024
2 parents b84c0ba + 46b08c9 commit 79114fb
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 9 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

Items for next release go here

* Extended `loo_pit()` to support discrete data by using a randomised probability integral transform.

# rstantools 2.4.0

* Update to match CRAN's patched version by @jgabry in #114
Expand Down
32 changes: 28 additions & 4 deletions R/loo-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
#' @return `loo_predict()`, `loo_linpred()`, and `loo_pit()`
#' (probability integral transform) methods should return a vector with length
#' equal to the number of observations in the data.
#' For discrete observations, the probability integral transform is randomised
#' to ensure theoretical uniformity, so set the random seed (e.g. with
#' `set.seed()`) to obtain reproducible results. For more details, see
#' Czado et al. (2009).
#' `loo_predictive_interval()` methods should return a two-column matrix
#' formatted in the same way as for [predictive_interval()].
#'
#' @template seealso-rstanarm-pkg
#' @template seealso-vignettes
#'
#' @template reference-randomised-pit

#' @rdname loo-prediction
#' @export
Expand Down Expand Up @@ -60,11 +63,32 @@ loo_pit.default <- function(object, y, lw, ...) {

# internal ----------------------------------------------------------------
# Compute leave-one-out probability integral transform (PIT) values.
#
# For each column j of `yrep`, the log-weights `lw[, j]` of the draws strictly
# below `y[j]` are exponentiated and summed. If any draw equals `y[j]`, the
# PIT is instead drawn uniformly between that sum and the sum that also
# includes the tied draws (randomised PIT; see Czado, C., Gneiting, T., and
# Held, L. (2009). Predictive model assessment for count data.
# Biometrics 65(4), 1254-1261). Because `stats::runif()` is used, results for
# data with ties depend on the state of the random number generator.
#
# @param y Vector of observations; `y[j]` is compared against `yrep[, j]`.
# @param yrep Matrix with one column per observation.
# @param lw Matrix of log-weights indexed by the same rows and columns as
#   `yrep`; must be non-NULL and entirely finite.
# @return Numeric vector with one PIT value per column of `yrep`, capped at 1.
.loo_pit <- function(y, yrep, lw) {
  if (is.null(lw) || !all(is.finite(lw))) {
    stop("lw needs to be not null and finite.")
  }
  pits <- vapply(seq_len(ncol(yrep)), function(j) {
    # Weight mass of the draws strictly below the observation.
    sel_min <- yrep[, j] < y[j]
    pit <- .exp_log_sum_exp(lw[sel_min, j])
    # Draws tied with the observation trigger the randomised PIT.
    sel_sup <- yrep[, j] == y[j]
    if (any(sel_sup)) {
      pit_sup <- pit + .exp_log_sum_exp(lw[sel_sup, j])
      pit <- stats::runif(1, pit, pit_sup)
    }
    pit
  }, FUN.VALUE = 1)
  if (any(pits > 1)) {
    # Pass the pieces straight to warning(): wrapping them in cat() would
    # print to stdout and hand warning() a NULL, i.e. an empty message.
    warning(
      "Some PIT values larger than 1! Largest: ",
      max(pits),
      "\nRounding PIT > 1 to 1."
    )
  }
  pmin(1, pits)
}

.exp_log_sum_exp <- function(x) {
m <- suppressWarnings(max(x))
exp(m + log(sum(exp(x - m))))
Expand Down
4 changes: 2 additions & 2 deletions man-roxygen/details-license.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @details In order to enable Stan functionality, \pkg{\link{rstantools}}
#' copies some files to your package. Since these files are licensed as GPL
#' >= 3, the same license applies to your package should you choose to
#' copies some files to your package. Since these files are licensed as
#' GPL >= 3, the same license applies to your package should you choose to
#' distribute it. Even if you don't use \pkg{\link{rstantools}} to create
#' your package, it is likely that you will be linking to \pkg{\link{Rcpp}} to
#' export the Stan C++ `stanmodel` objects to \R. Since
Expand Down
5 changes: 5 additions & 0 deletions man-roxygen/reference-randomised-pit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @references Czado, C., Gneiting, T., and Held, L. (2009).
#' Predictive Model Assessment for Count Data.
#' *Biometrics*. 65(4), 1254-1261.
#' doi:10.1111/j.1541-0420.2009.01191.x.
#' Journal version: <https://doi.org/10.1111/j.1541-0420.2009.01191.x>
10 changes: 10 additions & 0 deletions man/loo-prediction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/rstan_create_package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified tests/testthat/loo_pit.RDS
Binary file not shown.
Binary file added tests/testthat/loo_pit_discrete.RDS
Binary file not shown.
16 changes: 14 additions & 2 deletions tests/testthat/test-default-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ context("default methods")
set.seed(1111)
x <- matrix(rnorm(150), 50, 3)
y <- rnorm(ncol(x))
# Log-weights shared by the loo_pit tests. Each column has its own
# log(sum(exp(col))) subtracted (sweep() uses "-" by default), so exp(lw) sums
# to one within every column.
lw <- matrix(rnorm(150), 50, 3)
lw <- sweep(
lw,
MARGIN = 2,
STATS = apply(lw, 2, \(col) log(sum(exp(col)))),
check.margin = FALSE
)

test_that("posterior_interval.default hasn't changed", {
expect_equal_to_reference(
Expand All @@ -28,12 +35,18 @@ test_that("prior_summary.default works", {
expect_null(prior_summary(list(abc = "prior_info")))
})
test_that("loo_pit.default works", {
lw <- matrix(rnorm(150), 50, 3)
expect_equal_to_reference(
loo_pit(x, y, lw),
"loo_pit.RDS"
)
})
# x and y are rounded before being passed to loo_pit(). The seed is reset
# immediately before the call so the value compared with the stored reference
# file does not depend on earlier random number use.
test_that("loo_pit-default works for discrete data", {
set.seed(1111)
expect_equal_to_reference(
loo_pit(round(x), round(y), lw),
"loo_pit_discrete.RDS"
)
})
test_that("bayes_R2.default hasn't changed", {
expect_equal_to_reference(
bayes_R2(x, y),
Expand Down Expand Up @@ -82,4 +95,3 @@ test_that(".pred_errors throws errors", {
fixed = TRUE)
})


0 comments on commit 79114fb

Please sign in to comment.