Skip to content

Commit

Permalink
feature issue #1657
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 19, 2024
1 parent 4fff23b commit 2e6c14e
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 133 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.10
Date: 2024-09-16
Version: 2.21.11
Date: 2024-09-19
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
* Support stratified `cox` models via the new addition term `bhaz`. (#1489)
* Support futures for parallelization in the `cmdstanr` backend. (#1684)
* Add method `loo_epred` thanks to Aki Vehtari. (#1641)
* Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354)
* Add priorsense support via `create_priorsense_data.brmsfit`
thanks to Noa Kallioinen. (#1354)
* Vectorize censored log likelihoods in the Stan code when possible. (#1657)

### Bug Fixes

Expand Down
281 changes: 174 additions & 107 deletions R/stan-likelihood.R

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ stan_predictor.brmsframe <- function(x, prior, normalize, ...) {
px <- check_prefix(x)
resp <- usc(combine_prefix(px))
out <- list()
str_add_list(out) <- stan_response(x, normalize = normalize)
str_add_list(out) <- stan_response(x, normalize = normalize, ...)
valid_dpars <- valid_dpars(x)
family_files <- family_info(x, "include")
if (length(family_files)) {
Expand Down
58 changes: 51 additions & 7 deletions R/stan-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# of Stan code snippets to be pasted together later on

# Stan code for the response variables
stan_response <- function(bframe, normalize) {
stan_response <- function(bframe, threads, normalize, ...) {
stopifnot(is.brmsframe(bframe))
lpdf <- stan_lpdf_name(normalize)
family <- bframe$family
Expand Down Expand Up @@ -135,24 +135,68 @@ stan_response <- function(bframe, normalize) {
}
if (is.formula(bframe$adforms$cens)) {
str_add(out$data) <- glue(
" array[N{resp}] int<lower=-1,upper=2> cens{resp}; // indicates censoring\n"
" // censoring indicator: 0 = event, 1 = right, -1 = left, 2 = interval censored\n",
" array[N{resp}] int<lower=-1,upper=2> cens{resp};\n"
)
str_add(out$pll_args) <- glue(", data array[] int cens{resp}")
y2_expr <- get_ad_expr(bframe, "cens", "y2")
if (!is.null(y2_expr)) {
# interval censoring is required
is_interval_censored <- !is.null(y2_expr)
if (is_interval_censored) {
# some observations are interval censored
str_add(out$data) <- " // right censor points for interval censoring\n"
if (rtype == "int") {
str_add(out$data) <- glue(
" array[N{resp}] int rcens{resp};"
" array[N{resp}] int rcens{resp};\n"
)
str_add(out$pll_args) <- glue(", data array[] int rcens{resp}")
} else {
str_add(out$data) <- glue(
" vector[N{resp}] rcens{resp};"
" vector[N{resp}] rcens{resp};\n"
)
str_add(out$pll_args) <- glue(", data vector rcens{resp}")
}
str_add(out$data) <- " // right censor points for interval censoring\n"
}
n <- stan_nn(threads)
cens_indicators_def <- glue(
" // indices of censored data\n",
" int Nevent{resp} = 0;\n",
" int Nrcens{resp} = 0;\n",
" int Nlcens{resp} = 0;\n",
" int Nicens{resp} = 0;\n",
" array[N{resp}] int Jevent{resp};\n",
" array[N{resp}] int Jrcens{resp};\n",
" array[N{resp}] int Jlcens{resp};\n",
" array[N{resp}] int Jicens{resp};\n"
)
cens_indicators_comp <- glue(
" // collect indices of censored data\n",
" for (n in 1:N{resp}) {{\n",
stan_nn_def(threads),
" if (cens{resp}{n} == 0) {{\n",
" Nevent{resp} += 1;\n",
" Jevent{resp}[Nevent{resp}] = n;\n",
" }} else if (cens{resp}{n} == 1) {{\n",
" Nrcens{resp} += 1;\n",
" Jrcens{resp}[Nrcens{resp}] = n;\n",
" }} else if (cens{resp}{n} == -1) {{\n",
" Nlcens{resp} += 1;\n",
" Jlcens{resp}[Nlcens{resp}] = n;\n",
" }} else if (cens{resp}{n} == 2) {{\n",
" Nicens{resp} += 1;\n",
" Jicens{resp}[Nicens{resp}] = n;\n",
" }}\n",
" }}\n"
)
if (use_threading(threads)) {
# in threaded Stan code, gathering the indices has to be done on the fly
# inside the reduce_sum call since the indices are dependent on the slice
# of observations whose log likelihood is being evaluated
str_add(out$fun) <- " #include 'fun_add_int.stan'\n"
str_add(out$pll_def) <- cens_indicators_def
str_add(out$model_comp_basic) <- cens_indicators_comp
} else {
str_add(out$tdata_def) <- cens_indicators_def
str_add(out$tdata_comp) <- cens_indicators_comp
}
}
bounds <- bframe$frame$resp$bounds
Expand Down
14 changes: 14 additions & 0 deletions inst/chunks/fun_add_int.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/* add a single integer to an array of integers
* Args:
* x: array of integers
* y: a single integer
* Returns:
* an array of intergers of the same length as x
*/
array[] int add_int(array[] int x, int y) {
array[num_elements(x)] int out;
for (n in 1:num_elements(x)) {
out[n] = x[n] + y;
}
return out;
}
48 changes: 33 additions & 15 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ test_that("Stan code for multivariate models is correct", {
scode <- stancode(bform, dat, prior = bprior)
expect_match2(scode, "r_1_y2_3 = r_1[, 3]")
expect_match2(scode, "err_y1[n] = Y_y1[n] - mu_y1[n]")
expect_match2(scode, "target += normal_lccdf(Y_y1[n] | mu_y1[n], sigma_y1)")
expect_match2(scode, "target += normal_lccdf(Y_y1[Jrcens_y1[1:Nrcens_y1]] | mu_y1[Jrcens_y1[1:Nrcens_y1]], sigma_y1)")
expect_match2(scode, "target += skew_normal_lpdf(Y_y2 | mu_y2, omega_y2, alpha_y2)")
expect_match2(scode, "ps[1] = log(theta1_x) + poisson_log_lpmf(Y_x[n] | mu1_x[n])")
expect_match2(scode, "lprior += normal_lpdf(b_y1 | 0, 5)")
Expand Down Expand Up @@ -1338,13 +1338,17 @@ test_that("Stan code of response times models is correct", {

scode <- stancode(count | cens(cens) ~ Trt + (1|patient),
data = dat, family = exgaussian("inverse"))
expect_match2(scode, "exp_mod_normal_lccdf(Y[n] | mu[n] - beta, sigma, inv(beta))")
expect_match2(scode,
"target += exp_mod_normal_lccdf(Y[Jrcens[1:Nrcens]] | mu[Jrcens[1:Nrcens]] - beta, sigma, inv(beta));"
)

scode <- stancode(count ~ Trt, dat, family = shifted_lognormal())
expect_match2(scode, "target += lognormal_lpdf(Y - ndt | mu, sigma)")

scode <- stancode(count | cens(cens) ~ Trt, dat, family = shifted_lognormal())
expect_match2(scode, "target += lognormal_lcdf(Y[n] - ndt | mu[n], sigma)")
expect_match2(scode,
"target += lognormal_lcdf(Y[Jlcens[1:Nlcens]] - ndt | mu[Jlcens[1:Nlcens]], sigma);"
)

# test issue #837
scode <- stancode(mvbind(count, zBase) ~ Trt, data = dat,
Expand Down Expand Up @@ -1401,18 +1405,18 @@ test_that("weighted, censored, and truncated likelihoods are correct", {
"target += weights[n] * (binomial_logit_lpmf(Y[n] | trials[n], mu[n]));"
)

scode <- stancode(y | cens(x, y2) ~ 1, dat, poisson())
expect_match2(scode, "target += poisson_lpmf(Y[n] | mu[n]);")
scode <- stancode(y | cens(x, y2) ~ 1, dat, family = poisson())
expect_match2(scode, "target += poisson_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]]);")
expect_match2(scode, "poisson_lcdf(rcens[Jicens[1:Nicens]] | mu[Jicens[1:Nicens]])")

scode <- stancode(y | cens(x) ~ 1, dat, exponential())
expect_match2(scode, "target += exponential_lccdf(Y[n] | inv(mu[n]));")
scode <- stancode(y | cens(x) ~ 1, dat, family = cox())
expect_match2(scode, "target += cox_log_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);")

dat$x[1] <- 2
scode <- stancode(y | cens(x, y2) ~ 1, dat, gaussian())
expect_match2(scode, paste0(
"target += log_diff_exp(\n",
" normal_lcdf(rcens[n] | mu[n], sigma),"
))
scode <- stancode(y | cens(x, y2) ~ 1, dat, family = asym_laplace())
expect_match2(scode, "target += log_diff_exp(\n")
expect_match2(scode, "asym_laplace_lcdf(rcens[n] | mu[n], sigma, quantile),")

dat$x <- 1
expect_match2(stancode(y | cens(x) + weights(x) ~ 1, dat, exponential()),
"target += weights[n] * exponential_lccdf(Y[n] | inv(mu[n]));")
Expand Down Expand Up @@ -1684,7 +1688,9 @@ test_that("Stan code of addition term 'rate' is correct", {
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y | mu + log_denom, inv(sigma) * denom);")

scode <- stancode(y | rate(time) + cens(1) ~ x, data, geometric())
expect_match2(scode, "target += neg_binomial_2_lpmf(Y[n] | mu[n] * denom[n], 1 * denom[n]);")
expect_match2(scode,
"target += neg_binomial_2_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]] .* denom[Jevent[1:Nevent]], 1 * denom[Jevent[1:Nevent]]);"
)
})

test_that("Stan code of GEV models is correct", {
Expand Down Expand Up @@ -2130,7 +2136,7 @@ test_that("Stan code for missing value terms works correctly", {
scode <- stancode(bform, dat)
expect_match2(scode, "vector<lower=0,upper=1>[Nmi_x] Ymi_x;")
expect_match2(scode,
"target += beta_lpdf(Yl_x[n] | mu_x[n] * phi_x, (1 - mu_x[n]) * phi_x);"
"target += beta_lpdf(Y_x[Jevent_x[1:Nevent_x]] | mu_x[Jevent_x[1:Nevent_x]] * phi_x, (1 - mu_x[Jevent_x[1:Nevent_x]]) * phi_x);"
)

# tests #1608
Expand Down Expand Up @@ -2474,6 +2480,7 @@ test_that("threaded Stan code is correct", {

# only run if cmdstan >= 2.29 can be found on the system
# otherwise the canonicalized code will cause test failures
# TODO: switch to testing with rstan?
cmdstan_version <- try(cmdstanr::cmdstan_version(), silent = TRUE)
found_cmdstan <- !brms:::is_try_error(cmdstan_version)
skip_if_not(found_cmdstan && cmdstan_version >= "2.29.0")
Expand Down Expand Up @@ -2542,6 +2549,17 @@ test_that("threaded Stan code is correct", {
expect_match2(scode, "ps[1] = log(theta1) + poisson_log_lpmf(Y[nn] | mu1[n]);")
expect_match2(scode, "ptarget += log_sum_exp(ps);")
expect_match2(scode, "target += reduce_sum_static(partial_log_lik_lpmf,")

# test that code related to censoring is correct
scode <- stancode(
count | cens(Trt) ~ Age, dat, family = lognormal(),
threads = threading(4)
)
expect_match2(scode, "else if (cens[nn] == 1) {")
expect_match2(scode, "Jrcens[Nrcens] = n;")
expect_match2(scode,
"ptarget += lognormal_lcdf(Y[add_int(Jlcens[1:Nlcens], start - 1)] | mu[Jlcens[1:Nlcens]], sigma);"
)
})

test_that("Un-normalized Stan code is correct", {
Expand Down Expand Up @@ -2615,7 +2633,7 @@ test_that("Un-normalized Stan code is correct", {
scode <- stancode(
y | vint(size) + vreal(size) ~ x, data = dat, family = beta_binomial2,
prior = prior(gamma(0.1, 0.1), class = "tau"),
stanvars = stanvars, normalize = FALSE, backend = "cmdstanr"
stanvars = stanvars, normalize = FALSE,
)
expect_match2(scode, "target += beta_binomial2_lpmf(Y[n] | mu[n], tau, vint1[n], vreal1[n]);")
expect_match2(scode, "gamma_lupdf(tau | 0.1, 0.1);")
Expand Down

0 comments on commit 2e6c14e

Please sign in to comment.