Skip to content

Commit

Permalink
Fixes 975 by only removing leftmost array dimension if equal to 1
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jun 7, 2024
1 parent 02259ef commit 8ea0d4a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
28 changes: 25 additions & 3 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,24 @@ process_init.default <- function(init, ...) {
return(init)
}

#' Remove the leftmost dimension if equal to 1
#' @noRd
#' @param x An array like object
.remove_leftmost_dim <- function(x) {
dims <- dim(x)
if (length(dims) == 1) {
return(drop(x))
} else if (dims[1] == 1) {
new_dims <- dims[-1]
# Create a call to subset the array, maintaining all remaining dimensions
subset_expr <- as.call(c(as.name("["), list(x), 1, rep(TRUE, length(new_dims)), drop = FALSE))
new_x <- eval(subset_expr)
return(array(new_x, dim = new_dims))
} else {
return(x)
}
}

#' Write initial values to files if provided as posterior `draws` object
#' @noRd
#' @param init A type that inherits the `posterior::draws` class.
Expand Down Expand Up @@ -1097,9 +1115,13 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names)
inits = lapply(1:num_procs, function(draw_iter) {
init_i = lapply(variable_names, function(var_name) {
x = drop(posterior::draws_of(drop(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))))
return(x)
x = .remove_leftmost_dim(posterior::draws_of(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))
if (model_variables$parameters[[var_name]]$dimensions == 0) {
return(as.double(x))
} else {
return(x)
}
})
bad_names = unlist(lapply(variable_names, function(var_name) {
x = drop(posterior::draws_of(drop(
Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/resources/stan/issue_975.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
data {
int N;
vector[N] y;
}

parameters {
matrix[N, 1] mu;
vector<lower=0>[N] sigma;
}

model {
target += normal_lupdf(y | mu[:, 1] , sigma);
}
10 changes: 10 additions & 0 deletions tests/testthat/test-init-issue975.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
context("fitted-inits")
set_cmdstan_path()


test_that("Sample method works as init", {
mod <- testing_model("issue_975")
data <- list(N = 100, y = rnorm(100))
pf <- mod$pathfinder(data = data)
expect_no_error(fit <- mod$sample(data = data, init = pf))
})

0 comments on commit 8ea0d4a

Please sign in to comment.