Skip to content

Commit

Permalink
updates for ensuring all parametric priors are included
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 2, 2023
1 parent 3c57763 commit ca18e2e
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 18 deletions.
39 changes: 24 additions & 15 deletions R/add_stan_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -510,21 +510,30 @@ add_stan_data = function(jags_file, stan_file,
if(any(grep('## parametric effect priors', jags_file))){

# Get indices of parametric effects
min_paras <- as.numeric(sub('.*(?=.$)', '',
sub("\\:.*", "",
jags_file[grep('## parametric effect', jags_file) + 1]), perl=T))
max_paras <- as.numeric(substr(sub(".*\\:", "",
jags_file[grep('## parametric effect', jags_file) + 1]),
1, 1))
para_indices <- seq(min_paras, max_paras)

# Get names of parametric terms
int_included <- attr(ss_gam$pterms, 'intercept') == 1L
other_pterms <- attr(ss_gam$pterms, 'term.labels')
all_paras <- other_pterms
if(int_included){
all_paras <- c('(Intercept)', all_paras)
}
smooth_labs <- do.call(rbind, lapply(seq_along(ss_gam$smooth), function(x){
data.frame(label = ss_gam$smooth[[x]]$label,
term = paste(ss_gam$smooth[[x]]$term, collapse = ','),
class = class(ss_gam$smooth[[x]])[1])
}))
lpmat <- predict(ss_gam, type = 'lpmatrix',
exclude = smooth_labs$label)
para_indices <- which(apply(lpmat, 2, function(x) !all(x == 0)) == TRUE)
all_paras <- names(para_indices)
# min_paras <- as.numeric(sub('.*(?=.$)', '',
# sub("\\:.*", "",
# jags_file[grep('## parametric effect', jags_file) + 1]), perl=T))
# max_paras <- as.numeric(substr(sub(".*\\:", "",
# jags_file[grep('## parametric effect', jags_file) + 1]),
# 1, 1))
# para_indices <- seq(min_paras, max_paras)
#
# # Get names of parametric terms
# int_included <- attr(ss_gam$pterms, 'intercept') == 1L
# other_pterms <- attr(ss_gam$pterms, 'term.labels')
# all_paras <- other_pterms
# if(int_included){
# all_paras <- c('(Intercept)', all_paras)
# }

# Create prior lines for parametric terms
para_lines <- vector()
Expand Down
16 changes: 13 additions & 3 deletions R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2921,9 +2921,19 @@ add_trend_predictors = function(trend_formula,
paste0('// dynamic process models\n',
paste0(paste(plines, collapse = '\n')))
} else {
model_file[grep("// dynamic factor estimates", model_file, fixed = TRUE)] <-
paste0('// dynamic process models\n',
paste0(paste(plines, collapse = '\n')))
if(any(grepl("// dynamic factor estimates", model_file, fixed = TRUE))){
model_file[grep("// dynamic factor estimates", model_file, fixed = TRUE)] <-
paste0('// dynamic process models\n',
paste0(paste(plines, collapse = '\n')))
}

if(any(grepl("// trend means", model_file, fixed = TRUE))){
model_file[grep("// trend means", model_file, fixed = TRUE)] <-
paste0('// dynamic process models\n',
paste0(paste(plines, collapse = '\n'),
'// trend means'))
}

}

}
Expand Down
Binary file modified src/RcppExports.o
Binary file not shown.
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified src/trend_funs.o
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.
128 changes: 128 additions & 0 deletions tests/testthat/test-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,133 @@ test_that("trend_formula setup is working properly", {

})

# Check that parametric effect priors are properly incorporated in the
# model for a wide variety of model forms
test_that("parametric effect priors correctly incorporated in models", {
mod_data <- mvgam:::mvgam_examp_dat
mod_data$data_train$x1 <-
rnorm(NROW(mod_data$data_train))
mod_data$data_train$x2 <-
rnorm(NROW(mod_data$data_train))
mod_data$data_train$x3 <-
rnorm(NROW(mod_data$data_train))

# Observation formula; no trend
mod <- mvgam(y ~ s(season) + series:x1 +
series:x2 + series:x3,
trend_model = 'None',
data = mod_data$data_train,
family = gaussian(),
run_model = FALSE)

expect_true(any(grepl('// prior for seriesseries_3:x1...',
mod$model_file, fixed = TRUE)))
expect_true(any(grepl('// prior for (Intercept)...',
mod$model_file, fixed = TRUE)))

para_names <- paste0(paste0('// prior for seriesseries_', 1:3,
paste0(':x', 1:3, '...')))
for(i in seq_along(para_names)){
expect_true(any(grepl(para_names[i],
mod$model_file, fixed = TRUE)))
}

priors <- get_mvgam_priors(y ~ s(season) + series:x1 +
series:x2 + series:x3,
trend_model = 'None',
data = mod_data$data_train,
family = gaussian())
expect_true(any(grepl('seriesseries_1:x2',
priors$param_name)))
expect_true(any(grepl('seriesseries_2:x3',
priors$param_name)))


# Observation formula; complex trend
mod <- mvgam(y ~ s(season) + series:x1 + series:x2 + series:x3,
trend_model = 'VARMA',
data = mod_data$data_train,
family = gaussian(),
run_model = FALSE)

expect_true(any(grepl('// prior for seriesseries_3:x1...',
mod$model_file, fixed = TRUE)))
expect_true(any(grepl('// prior for (Intercept)...',
mod$model_file, fixed = TRUE)))

para_names <- paste0(paste0('// prior for seriesseries_', 1:3,
paste0(':x', 1:3, '...')))
for(i in seq_along(para_names)){
expect_true(any(grepl(para_names[i],
mod$model_file, fixed = TRUE)))
}

priors <- get_mvgam_priors(y ~ s(season) + series:x1 +
series:x2 + series:x3,
trend_model = 'VARMA',
data = mod_data$data_train,
family = gaussian())
expect_true(any(grepl('seriesseries_1:x2',
priors$param_name)))
expect_true(any(grepl('seriesseries_2:x3',
priors$param_name)))

# Trend formula; RW
mod <- mvgam(y ~ 1,
trend_formula = ~ s(season) + trend:x1 +
trend:x2 + trend:x3,
trend_model = 'RW',
data = mod_data$data_train,
family = gaussian(),
run_model = FALSE)

expect_true(any(grepl('// prior for (Intercept)...',
mod$model_file, fixed = TRUE)))

para_names <- paste0(paste0('// prior for trendtrend', 1:3,
paste0(':x', 1:3, '_trend...')))
for(i in seq_along(para_names)){
expect_true(any(grepl(para_names[i],
mod$model_file, fixed = TRUE)))
}

priors <- get_mvgam_priors(y ~ 1,
trend_formula = ~ s(season) + trend:x1 +
trend:x2 + trend:x3,
trend_model = 'RW',
data = mod_data$data_train,
family = gaussian())
expect_true(any(grepl('trendtrend1:x1_trend',
priors$param_name)))
expect_true(any(grepl('trendtrend2:x3_trend',
priors$param_name)))

# Trend formula; VARMA
mod <- mvgam(y ~ 1,
trend_formula = ~ s(season) + trend:x1 + trend:x2 + trend:x3,
trend_model = 'VARMA',
data = mod_data$data_train,
family = gaussian(),
run_model = FALSE)

expect_true(any(grepl('// prior for (Intercept)...',
mod$model_file, fixed = TRUE)))

para_names <- paste0(paste0('// prior for trendtrend', 1:3,
paste0(':x', 1:3, '_trend...')))
for(i in seq_along(para_names)){
expect_true(any(grepl(para_names[i],
mod$model_file, fixed = TRUE)))
}

priors <- get_mvgam_priors(y ~ 1,
trend_formula = ~ s(season) + trend:x1 + trend:x2 + trend:x3,
trend_model = 'RW',
data = mod_data$data_train,
family = gaussian())
expect_true(any(grepl('trendtrend1:x1_trend',
priors$param_name)))
expect_true(any(grepl('trendtrend2:x3_trend',
priors$param_name)))
})

0 comments on commit ca18e2e

Please sign in to comment.