-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented two models, and workflow to how to add data to template t…
…o running model
- Loading branch information
Showing
18 changed files
with
1,134 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#' Assign data to template | ||
#' | ||
#' @param model_template output from rmot_model | ||
#' @param ... data-masking name-value pairs | ||
#' | ||
#' @return updated named list with your data assigned to Stan model parameters | ||
#' @export | ||
#' | ||
#' @examples | ||
#' rmot_assign_data(X = Loblolly$age, Y = Loblolly$height) | ||
rmot_assign_data <- function(model_template, ...){ | ||
# Grab user expressions | ||
user_code <- rlang::enexprs(..., .check_assign = TRUE) | ||
|
||
# Grab the names | ||
fields <- names(user_code) | ||
|
||
#TODO: Check names are in model_template | ||
|
||
# Evaluate the RHS of expressions (the values) | ||
data <- purrr::map(user_code, | ||
eval) | ||
|
||
for(i in fields){ | ||
model_template <- purrr::list_modify(model_template, !!!data[i]) | ||
} | ||
|
||
#TODO: Check if N is supplied, if not, assign by default to length(X) and give warning | ||
|
||
return(model_template) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,51 @@ | ||
# Set list structures for different models | ||
# An example for lm | ||
|
||
rmot_lm <- function(){ | ||
list(X = NULL, | ||
Y = NULL, | ||
N = NULL, | ||
model = "linear") | ||
} | ||
#' Select data configuration template for rmot supported model | ||
#' | ||
#' @param model model name character string | ||
#' | ||
#' @return named list that matches Stan model parameters | ||
#' @export | ||
#' | ||
#' @examples | ||
#' rmot_model("linear") | ||
|
||
rmot_model <- function(model=NULL){ | ||
|
||
# Need a mechanism to select models | ||
# rmot_config(model = "linear") | ||
#TODO: Need a mechanism to check model requested in one that is supported by rmot | ||
|
||
rmot_config <- function(model=NULL){ | ||
output <- switch(model, | ||
linear = rmot_lm()) | ||
linear = rmot_lm(), | ||
constant_single = rmot_cgs()) | ||
|
||
class(output) <- "rmot_object" | ||
|
||
return(output) | ||
} | ||
|
||
# Need a mechanism to take user data and assign to slots in list | ||
rmot_assign_data <- function(model_template, field, data){ | ||
purrr::assign_in(model_template, field, data) | ||
} | ||
|
||
|
||
rmot_assign_data <- function(model_template, ...){ | ||
# Grab user expressions | ||
user_code <- rlang::enexprs(..., .check_assign = TRUE) | ||
#' Data configuration template for linear model | ||
#' @keywords internal | ||
#' @noRd | ||
|
||
# Evaluate the RHS of expressions (the values) | ||
data <- purrr::map(user_code, | ||
eval) | ||
|
||
# Grab the names | ||
fields <- names(user_code) | ||
|
||
for(i in fields){ | ||
model_template <- purrr::list_modify(model_template, !!!data[i]) | ||
} | ||
|
||
return(model_template) | ||
rmot_lm <- function(){ | ||
list(X = NULL, | ||
Y = NULL, | ||
N = NULL, | ||
model = "linear") | ||
} | ||
|
||
#' Data configuration template for constant growth single species model | ||
#' @keywords internal | ||
#' @noRd | ||
|
||
rmot_cgs <- function(){ | ||
list(N_obs = NULL, | ||
N_ind = NULL, | ||
S_obs = NULL, | ||
census = NULL, | ||
census_interval = NULL, | ||
id_factor = NULL, | ||
S_0_obs = NULL, | ||
model = "constant_single") | ||
} | ||
|
||
|
||
|
||
list_rename = function(data, ...) { | ||
mapping = sapply( | ||
rlang::enquos(...), | ||
rlang::as_name | ||
) | ||
new_names = stats::setNames(nm=names(data)) | ||
# `new_name = old_name` for consistency with `dplyr::rename` | ||
new_names[mapping] = names(mapping) | ||
# for `old_name = new_name` use: `new_names[names(mapping)] = mapping` | ||
stats::setNames(data, new_names) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#' Run a linear model in Stan | ||
#' | ||
#' @param model_template model template generated by rmot_model and updated by rmot_assign_data | ||
#' @param ... additional arguments passed to rstan::sampling | ||
#' | ||
#' @return Stanfit model output | ||
#' @export | ||
#' | ||
#' @examples | ||
#' mtcars | ||
#' rmot_lm(mtcars$mpg, mtcars$disp) | ||
rmot_run <- function(model_template, ...) { | ||
|
||
# Detect model | ||
out <- switch(model_template$model, | ||
linear = rstan::sampling(stanmodels$linear, data = model_template, ...), | ||
constant_single = rstan::sampling(stanmodels$constant_single, data = model_template, ...)) | ||
|
||
return(out) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
list_rename = function(data, ...) { | ||
mapping = sapply( | ||
rlang::enquos(...), | ||
rlang::as_name | ||
) | ||
new_names = stats::setNames(nm=names(data)) | ||
# `new_name = old_name` for consistency with `dplyr::rename` | ||
new_names[mapping] = names(mapping) | ||
# for `old_name = new_name` use: `new_names[names(mapping)] = mapping` | ||
stats::setNames(data, new_names) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
//Constant Growth - Single species | ||
|
||
// Data structure | ||
data { | ||
int N_obs; | ||
int N_ind; | ||
real S_obs[N_obs]; | ||
int census[N_obs]; | ||
real census_interval[N_obs]; | ||
int id_factor[N_obs]; | ||
real S_0_obs[N_ind]; | ||
} | ||
|
||
// The parameters accepted by the model. | ||
parameters { | ||
//Individual level | ||
real<lower=0> ind_S_0[N_ind]; | ||
real<lower=0> ind_beta[N_ind]; | ||
|
||
real species_beta_mu; | ||
real<lower=0> species_beta_sigma; | ||
|
||
//Global level | ||
real<lower=0> global_error_sigma; | ||
} | ||
|
||
// The model to be estimated. | ||
model { | ||
real S_hat[N_obs]; | ||
real G_hat[N_obs]; | ||
|
||
for(i in 1:N_obs){ | ||
if(id_factor[i+1]==id_factor[i]){ | ||
if(census[i]==1){//Fits the first size | ||
S_hat[i] = ind_S_0[id_factor[i]]; | ||
} | ||
|
||
if(i < N_obs){ //Analytic solution | ||
G_hat[i] = ind_beta[id_factor[i]]; | ||
S_hat[i+1] = S_hat[i] + G_hat[i]*census_interval[i+1]; | ||
} | ||
} else { | ||
G_hat[i] = 0; //Gives 0 as the growth estimate for the last data point. | ||
} | ||
} | ||
|
||
//Likelihood | ||
S_obs ~ normal(S_hat, global_error_sigma); | ||
|
||
//Priors | ||
//Individual level | ||
ind_S_0 ~ normal(S_0_obs, global_error_sigma); | ||
ind_beta ~ lognormal(species_beta_mu, | ||
species_beta_sigma); | ||
|
||
//Species level | ||
species_beta_mu ~ normal(0.1, 1); | ||
species_beta_sigma ~cauchy(0.1, 1); | ||
|
||
//Global level | ||
global_error_sigma ~cauchy(0.1, 1); | ||
} | ||
|
||
// The output | ||
generated quantities { | ||
real S_hat[N_obs]; | ||
real G_hat[N_obs]; | ||
|
||
for(i in 1:N_obs){ | ||
if(id_factor[i+1]==id_factor[i]){ | ||
if(census[i]==1){//Fits the first size | ||
S_hat[i] = ind_S_0[id_factor[i]]; | ||
} | ||
|
||
if(i < N_obs){ //Analytic solution | ||
G_hat[i] = ind_beta[id_factor[i]]; | ||
S_hat[i+1] = S_hat[i] + G_hat[i]*census_interval[i+1]; | ||
} | ||
} else { | ||
G_hat[i] = 0; //Gives 0 as the growth estimate for the last data point. | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.