-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
i2: support multiple categorical covariates #4
Changes from 9 commits
161da12
f1fa3df
bfaaa9e
673a36b
f0e2cae
07df22f
9280ceb
44541ea
08c4c38
35fbfe0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ src/stan/**/*.exe | |
src/stan/**/*.EXE | ||
inst/doc | ||
.idea | ||
*.png |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,41 +63,13 @@ scova <- R6::R6Class( | |
# Identify columns with no variance and remove them | ||
variance_per_column <- apply(mm, 2, var) | ||
relevant_columns <- which(variance_per_column != 0) | ||
mm_reduced <- mm[, relevant_columns] | ||
mm_reduced <- mm[, relevant_columns, drop = FALSE] | ||
private$design_matrix <- mm_reduced | ||
}, | ||
build_covariate_lookup_table = function() { | ||
# Extract column names | ||
col_names <- colnames(private$design_matrix) | ||
|
||
# Split column names based on the ':' delimiter | ||
split_data <- stringr::str_split(col_names, ":", simplify = TRUE) | ||
|
||
# Convert the matrix to a data.table | ||
dt <- data.table::as.data.table(split_data) | ||
|
||
# Set the new column names | ||
data.table::setnames(dt, private$all_formula_vars) | ||
|
||
for (col_name in names(dt)) { | ||
# Find the matching formula variable for current column | ||
matching_formula_var <- private$all_formula_vars[which(startsWith(col_name, private$all_formula_vars))] | ||
if (length(matching_formula_var) > 0) { | ||
pattern_to_remove <- paste0("^", matching_formula_var) | ||
dt[, (col_name) := stringr::str_remove_all(get(col_name), pattern_to_remove)] | ||
} | ||
} | ||
|
||
# Declare variables to suppress notes when compiling package | ||
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 | ||
p <- NULL | ||
|
||
# .I is a special symbol in data.table for row number | ||
dt[, p := .I] | ||
|
||
# Reorder columns to have 'p' first | ||
data.table::setcolorder(dt, "p") | ||
private$covariate_lookup_table <- dt | ||
private$covariate_lookup_table <- build_covariate_lookup_table(private$data, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moved this logic out into its own function, mostly for ease of unit testing (it could also be re-used if/when we add the Ct model). |
||
private$design_matrix, | ||
private$all_formula_vars) | ||
}, | ||
recover_covariate_names = function(dt) { | ||
# Declare variables to suppress notes when compiling package | ||
|
@@ -110,35 +82,46 @@ scova <- R6::R6Class( | |
|
||
dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)] | ||
if ("p" %in% colnames(dt)) { | ||
dt_out <- dt_out[private$covariate_lookup_table, on = "p"][, `:=`(p = NULL)] | ||
dt_out <- dt_out[private$covariate_lookup_table, on = "p", nomatch = NULL][, `:=`(p = NULL)] | ||
} | ||
dt_out | ||
}, | ||
summarise_pop_fit = function(time_range, | ||
summarise, | ||
n_draws) { | ||
|
||
has_covariates <- length(private$all_formula_vars) > 0 | ||
|
||
# Declare variables to suppress notes when compiling package | ||
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 | ||
t0_pop <- tp_pop <- ts_pop <- m1_pop <- m2_pop <- m3_pop <- NULL | ||
beta_t0 <- beta_tp <- beta_ts <- beta_m1 <- beta_m2 <- beta_m3 <- NULL | ||
k <- p <- .draw <- t_id <- mu <- NULL | ||
|
||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", | ||
"m1_pop[k]", "m2_pop[k]", "m3_pop[k]") | ||
if (has_covariates) { | ||
params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", | ||
"beta_m1[p]", "beta_m2[p]", "beta_m3[p]") | ||
} | ||
|
||
params_proc <- rlang::parse_exprs(params) | ||
|
||
dt_samples_wide <- tidybayes::spread_draws( | ||
private$fitted, | ||
t0_pop[k], tp_pop[k], ts_pop[k], | ||
m1_pop[k], m2_pop[k], m3_pop[k], | ||
beta_t0[p], beta_tp[p], beta_ts[p], | ||
beta_m1[p], beta_m2[p], beta_m3[p]) |> | ||
private$fitted, !!!params_proc) |> | ||
data.table() | ||
|
||
dt_samples_wide <- dt_samples_wide[.draw %in% 1:n_draws] | ||
|
||
dt_samples_wide[, `:=`(.chain = NULL, .iteration = NULL)] | ||
|
||
if (!has_covariates) { | ||
# there are no covariates, so add dummy column | ||
# that will be removed after processing | ||
dt_samples_wide$p <- 1 | ||
} | ||
|
||
data.table::setcolorder(dt_samples_wide, c("k", "p", ".draw")) | ||
|
||
if (length(private$all_formula_vars) > 0) { | ||
if (has_covariates) { | ||
logger::log_info("Adjusting by regression coefficients") | ||
dt_samples_wide <- private$adjust_parameters(dt_samples_wide) | ||
} | ||
|
@@ -165,6 +148,11 @@ scova <- R6::R6Class( | |
} | ||
|
||
data.table::setcolorder(dt_out, c("t", "p", "k")) | ||
|
||
if (!has_covariates) { | ||
dt_out[, p:= NULL] | ||
} | ||
dt_out | ||
}, | ||
prepare_stan_data = function() { | ||
stan_id <- titre <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- t_since_min_date <- NULL | ||
|
@@ -228,7 +216,8 @@ scova <- R6::R6Class( | |
#' for required columns: \code{vignette("data", package = "epikinetics")}. | ||
#' @param file_path Optional file path to model inputs in CSV format. One of data or file must be provided. | ||
#' @param priors Object of type \link[epikinetics]{scova_priors}. Default scova_priors(). | ||
#' @param covariate_formula Formula specifying hierarchical structure of model. Default ~0. | ||
#' @param covariate_formula Formula specifying linear regression model. Note all variables in the formula | ||
#' will be treated as categorical variables. Default ~0. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, I might be missing something, but in principle there is no need to treat covariates as categorical. There is certainly a need for continuous variables in the Ct version of the model, and I can imagine others possibly wanting continuous variables for this model too. It shouldn't change the overall regression parameters. It might make the covariate name recovery a bit stranger, and perhaps the "bespoke" code I wrote for this model will need some tweaking. But we should aim to include continuous variables at some point too. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, totally agree that we should accept these in principle; it will just be a bit fiddly to get that working, so I figured this was a reasonable intermediate step. Will add an issue 👍 |
||
#' @param preds_sd Standard deviation of predictor coefficients. Default 0.25. | ||
#' @param time_type One of 'relative' or 'absolute'. Default 'relative'. | ||
initialize = function(priors = scova_priors(), | ||
|
@@ -296,14 +285,23 @@ scova <- R6::R6Class( | |
extract_population_parameters = function(n_draws = 2500, | ||
human_readable_covariates = TRUE) { | ||
private$check_fitted() | ||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", | ||
"m3_pop[k]", "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", | ||
"beta_m2[p]", "beta_m3[p]") | ||
has_covariates <- length(private$all_formula_vars) > 0 | ||
|
||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", "m3_pop[k]") | ||
|
||
if (has_covariates) { | ||
params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", "beta_m2[p]", "beta_m3[p]") | ||
} | ||
|
||
logger::log_info("Extracting parameters") | ||
dt_out <- private$extract_parameters(params, n_draws) | ||
|
||
data.table::setcolorder(dt_out, c("k", "p", ".draw")) | ||
if (has_covariates){ | ||
data.table::setcolorder(dt_out, c("p", "k", ".draw")) | ||
} else { | ||
data.table::setcolorder(dt_out, c("k", ".draw")) | ||
} | ||
|
||
data.table::setnames(dt_out, ".draw", "draw") | ||
|
||
if (length(private$all_formula_vars) > 0) { | ||
|
@@ -338,6 +336,7 @@ scova <- R6::R6Class( | |
dt_out <- private$extract_parameters(params, n_draws) | ||
|
||
data.table::setcolorder(dt_out, c("n", "k", ".draw")) | ||
|
||
data.table::setnames(dt_out, c("n", ".draw"), c("stan_id", "draw")) | ||
|
||
if (human_readable_covariates) { | ||
|
@@ -408,14 +407,20 @@ scova <- R6::R6Class( | |
human_readable_covariates = FALSE) | ||
|
||
logger::log_info("Calculating peak and switch titre values") | ||
|
||
by <- c("k", "draw") | ||
if ("p" %in% colnames(dt_peak_switch)) { | ||
by <- c("p", by) | ||
} | ||
|
||
dt_peak_switch[, `:=`( | ||
mu_0 = scova_simulate_trajectory( | ||
0, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), | ||
mu_p = scova_simulate_trajectory( | ||
tp_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), | ||
mu_s = scova_simulate_trajectory( | ||
ts_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop)), | ||
by = c("p", "k", "draw")] | ||
by = by] | ||
|
||
logger::log_info("Recovering covariate names") | ||
dt_peak_switch <- private$recover_covariate_names(dt_peak_switch) | ||
|
@@ -482,7 +487,7 @@ scova <- R6::R6Class( | |
dt_params_ind_traj <- scova_simulate_trajectories(dt_params_ind_trim) | ||
|
||
dt_params_ind_traj <- data.table::setDT(convert_log_scale_inverse_cpp( | ||
dt_params_ind_traj, vars_to_transform = "mu")) | ||
dt_params_ind_traj, vars_to_transform = "mu")) | ||
|
||
logger::log_info("Recovering covariate names") | ||
dt_params_ind_traj <- private$recover_covariate_names(dt_params_ind_traj) | ||
|
@@ -499,18 +504,19 @@ scova <- R6::R6Class( | |
by = c(private$all_formula_vars, "stan_id", "titre_type")] | ||
|
||
if (summarise) { | ||
logger::log_info("Summarising into population quantiles") | ||
logger::log_info("Resampling") | ||
dt_out <- dt_out[ | ||
!is.nan(mu), .(pop_mu_sum = mean(mosaic::resample(mu))), | ||
by = c("calendar_date", "draw", "titre_type")] | ||
|
||
logger::log_info("Summarising into population quantiles") | ||
dt_out <- summarise_draws( | ||
dt_out, | ||
column_name = "pop_mu_sum", | ||
by = c("calendar_date", "titre_type")) | ||
} | ||
|
||
dt_out[, time_shift:= time_shift] | ||
dt_out[, time_shift := time_shift] | ||
} | ||
) | ||
) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just in case this becomes a one-column matrix, `drop = FALSE` stops it from being converted to a vector with its colnames removed.