Skip to content

Commit

Permalink
Fixes 975 by only removing leftmost array dimension if equal to 1 (#993)
Browse files Browse the repository at this point in the history
* Fixes 975 by only removing leftmost array dimension if equal to 1

* Update tests, fix windows error

---------

Co-authored-by: Andrew Johnson <[email protected]>
  • Loading branch information
SteveBronder and andrjohns authored Jun 8, 2024
1 parent 356fa04 commit 722d196
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
32 changes: 27 additions & 5 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,24 @@ process_init.default <- function(init, ...) {
return(init)
}

#' Remove the leftmost dimension if equal to 1
#' @noRd
#' @param x An array like object
.remove_leftmost_dim <- function(x) {
dims <- dim(x)
if (length(dims) == 1) {
return(drop(x))
} else if (dims[1] == 1) {
new_dims <- dims[-1]
# Create a call to subset the array, maintaining all remaining dimensions
subset_expr <- as.call(c(as.name("["), list(x), 1, rep(TRUE, length(new_dims)), drop = FALSE))
new_x <- eval(subset_expr)
return(array(new_x, dim = new_dims))
} else {
return(x)
}
}

#' Write initial values to files if provided as posterior `draws` object
#' @noRd
#' @param init A type that inherits the `posterior::draws` class.
Expand Down Expand Up @@ -1097,9 +1115,13 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names)
inits = lapply(1:num_procs, function(draw_iter) {
init_i = lapply(variable_names, function(var_name) {
x = drop(posterior::draws_of(drop(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))))
return(x)
x = .remove_leftmost_dim(posterior::draws_of(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))
if (model_variables$parameters[[var_name]]$dimensions == 0) {
return(as.double(x))
} else {
return(x)
}
})
bad_names = unlist(lapply(variable_names, function(var_name) {
x = drop(posterior::draws_of(drop(
Expand Down Expand Up @@ -1295,13 +1317,13 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
# Calculate unique draws based on 'lw' using base R functions
unique_draws = length(unique(draws_df$lw))
if (num_procs > unique_draws) {
if (inherits(init, " CmdStanPathfinder ")) {
if (inherits(init, "CmdStanPathfinder")) {
algo_name = " Pathfinder "
extra_msg = " Try running Pathfinder with psis_resample=FALSE."
} else if (inherits(init, "CmdStanVB")) {
algo_name = " CmdStanVB "
extra_msg = ""
} else if (inherits(init, " CmdStanLaplace ")) {
} else if (inherits(init, "CmdStanLaplace")) {
algo_name = " CmdStanLaplace "
extra_msg = ""
} else {
Expand Down
25 changes: 25 additions & 0 deletions tests/testthat/test-model-init.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,28 @@ test_that("Initial values for single-element containers treated correctly", {
)
)
})

test_that("Pathfinder inits do not drop dimensions", {
modcode <- "
data {
int N;
vector[N] y;
}
parameters {
matrix[N, 1] mu;
matrix[1, N] mu_2;
vector<lower=0>[N] sigma;
}
model {
target += normal_lupdf(y | mu[:, 1], sigma);
target += normal_lupdf(y | mu_2[1], sigma);
}
"
mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE)
data <- list(N = 100, y = rnorm(100))
pf <- mod$pathfinder(data = data, psis_resample = FALSE)
expect_no_error(fit <- mod$sample(data = data, init = pf, chains = 1,
iter_warmup = 100, iter_sampling = 100))
})

0 comments on commit 722d196

Please sign in to comment.