diff --git a/NEWS.md b/NEWS.md index e001d5ab..4aa50774 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,14 +14,14 @@ unique url. ### Some of the new plots and features in ShinyStan app * Rebranding (new look to reflect changes to Stan logo and website) * HMC/NUTS diagnostic plots ('Diagnose' page, 'HMC/NUTS (plots)' tab) -* Specify arbitrary transformations (e.g. log, logit, sqrt, etc.) for density, +* Specify transformations (e.g. log, logit, sqrt, etc.) for density, histogram, bivariate, trivariate plots (on 'Explore' page) and HMC/NUTS diagnostics plots. -* Bivariate scatterplot plot also shows divergent iterations and max treedepth +* Many plots can now also be saved as pdf +* Bivariate scatterplot plot also shows divergent transitions and max treedepth saturation (on 'Explore' page) * (Experimental) Introduce basic graphical posterior predictive checking ('Diagnose' page, 'PPcheck' tab) for limited class of models -* Option to show partial autocorrelations -('Diagnose' page, 'Autocorrelation' tab) +* Option to show partial autocorrelations ('Diagnose' page, 'Autocorrelation' tab) * Better customization of of posterior summary statistics table * Many improvements to GUI design diff --git a/R/convenience.R b/R/convenience.R index 670b2735..36fb9a3c 100644 --- a/R/convenience.R +++ b/R/convenience.R @@ -65,11 +65,12 @@ retrieve_sd <- function(sso, pars) { sp_check <- function(sso) { - if (is.na(sso@sampler_params[[1]])) stop("Only available for Stan models.") + if (identical(sso@sampler_params, list(NA))) + stop("No sampler parameters found", call. = FALSE) } retrieve_max_treedepth <- function(sso, inc_warmup = FALSE) { - sp_check() + sp_check(sso) rows <- if (inc_warmup) 1:sso@nIter else (sso@nWarmup+1):sso@nIter max_td <- sapply(sso@sampler_params, function(x) max(x[rows,"treedepth__"])) @@ -78,7 +79,7 @@ retrieve_max_treedepth <- function(sso, inc_warmup = FALSE) { } retrieve_prop_divergent <- function(sso, inc_warmup = FALSE) { - sp_check() + sp_check(sso) rows <- if (inc_warmup) 1:sso@nIter else (sso@nWarmup+1):sso@nIter prop_div <- sapply(sso@sampler_params, function(x) mean(x[rows,"n_divergent__"])) @@ -87,7 +88,7 @@ retrieve_prop_divergent <- function(sso, inc_warmup = FALSE) { } retrieve_avg_stepsize <- function(sso, inc_warmup = FALSE) { - sp_check() + sp_check(sso) rows <- if (inc_warmup) 1:sso@nIter else (sso@nWarmup+1):sso@nIter avg_ss <- sapply(sso@sampler_params, function(x) mean(x[rows,"stepsize__"])) @@ -96,7 +97,7 @@ retrieve_avg_stepsize <- function(sso, inc_warmup = FALSE) { } retrieve_avg_accept <- function(sso, inc_warmup = FALSE) { - sp_check() + sp_check(sso) rows <- if (inc_warmup) 1:sso@nIter else (sso@nWarmup+1):sso@nIter avg_accept <- sapply(sso@sampler_params, function(x) mean(x[rows,"accept_stat__"])) diff --git a/R/misc.R b/R/misc.R index 276f9bde..5cf87455 100644 --- a/R/misc.R +++ b/R/misc.R @@ -126,3 +126,38 @@ set_ppcheck_defaults <- function(appDir, yrep_name, y_name = "y") { ) } +.retrieve <- function(sso, what, ...) { + if (what %in% c("rhat", "rhats", "Rhat", "Rhats", "r_hat", "R_hat")) { + return(retrieve_rhat(sso, ...)) + } + if (what %in% c("N_eff","n_eff", "neff", "Neff", "ess","ESS")) { + return(retrieve_neff(sso, ...)) + } + if (grepl_ic("mean", what)) { + return(retrieve_mean(sso, ...)) + } + if (grepl_ic("sd", what)) { + return(retrieve_sd(sso, ...)) + } + if (what %in% c("se_mean", "mcse")) { + return(retrieve_mcse(sso, ...)) + } + if (grepl_ic("quant", what)) { + return(retrieve_quant(sso, ...)) + } + if (grepl_ic("median", what)) { + return(retrieve_median(sso, ...)) + } + if (grepl_ic("tree", what) | grepl_ic("depth", what)) { + return(retrieve_max_treedepth(sso, ...)) + } + if (grepl_ic("step", what)) { + return(retrieve_avg_stepsize(sso, ...)) + } + if (grepl_ic("diverg", what)) { + return(retrieve_prop_divergent(sso, ...)) + } + if (grepl_ic("accept", what)) { + return(retrieve_avg_accept(sso, ...)) + } +} \ No newline at end of file diff --git a/R/retrieve.R b/R/retrieve.R index dac4ebd4..b5e94a09 100644 --- a/R/retrieve.R +++ b/R/retrieve.R @@ -62,38 +62,5 @@ retrieve <- function(sso, what, ...) { sso_check(sso) - - if (what %in% c("rhat", "rhats", "Rhat", "Rhats", "r_hat", "R_hat")) { - return(retrieve_rhat(sso, ...)) - } - if (what %in% c("N_eff","n_eff", "neff", "Neff", "ess","ESS")) { - return(retrieve_neff(sso, ...)) - } - if (grepl_ic("mean", what)) { - return(retrieve_mean(sso, ...)) - } - if (grepl_ic("sd", what)) { - return(retrieve_sd(sso, ...)) - } - if (what %in% c("se_mean", "mcse")) { - return(retrieve_mcse(sso, ...)) - } - if (grepl_ic("quant", what)) { - return(retrieve_quant(sso, ...)) - } - if (grepl_ic("median", what)) { - return(retrieve_median(sso, ...)) - } - if (grepl_ic("tree", what) | grepl_ic("depth", what)) { - return(retrieve_max_treedepth(sso, ...)) - } - if (grepl_ic("step", what)) { - return(retrieve_avg_stepsize(sso, ...)) - } - if (grepl_ic("diverg", what)) { - return(retrieve_prop_divergent(sso, ...)) - } - if (grepl_ic("accept", what)) { - return(retrieve_avg_accept(sso, ...)) - } + .retrieve(sso, what, ...) }