Skip to content

Commit

Permalink
Merge pull request #11 from seroanalytics/i6
Browse files Browse the repository at this point in the history
simplify input data format
  • Loading branch information
hillalex authored Oct 3, 2024
2 parents 3960395 + 8632ba8 commit 9a808d2
Show file tree
Hide file tree
Showing 14 changed files with 5,679 additions and 5,610 deletions.
62 changes: 32 additions & 30 deletions R/biokinetics.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ biokinetics <- R6::R6Class(
mm
},
construct_design_matrix = function() {
var <- stan_id <- NULL
dt_design_matrix <- private$data[, .SD, .SDcols = private$all_formula_vars, by = stan_id] |>
var <- pid <- NULL
dt_design_matrix <- private$data[, .SD, .SDcols = private$all_formula_vars, by = pid] |>
unique()

# Build the full design matrix using model.matrix
Expand Down Expand Up @@ -155,23 +155,21 @@ biokinetics <- R6::R6Class(
dt_out
},
prepare_stan_data = function() {
stan_id <- titre <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- t_since_min_date <- NULL
pid <- value <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- t_since_min_date <- NULL
stan_data <- list(
N = private$data[, .N],
N_events = private$data[, data.table::uniqueN(stan_id)],
id = private$data[, stan_id],
titre = private$data[, titre],
N_events = private$data[, data.table::uniqueN(pid)],
id = private$data[, pid],
value = private$data[, value],
censored = private$data[, censored],
titre_type = private$data[, titre_type_num],
preds_sd = private$preds_sd,
K = private$data[, data.table::uniqueN(titre_type)],
N_uncens = private$data[censored == 0, .N],
N_lo = private$data[censored == -2, .N],
N_me = private$data[censored == -1, .N],
N_hi = private$data[censored == 1, .N],
uncens_idx = private$data[censored == 0, obs_id],
cens_lo_idx = private$data[censored == -2, obs_id],
cens_me_idx = private$data[censored == -1, obs_id],
cens_hi_idx = private$data[censored == 1, obs_id])

if (private$time_type == "relative") {
Expand Down Expand Up @@ -226,17 +224,13 @@ biokinetics <- R6::R6Class(
covariate_formula = ~0,
preds_sd = 0.25,
time_type = "relative") {
if (!inherits(priors, "biokinetics_priors")) {
stop("'priors' must be of type 'biokinetics_priors'")
}
validate_priors(priors)
private$priors <- priors
validate_numeric(preds_sd)
private$preds_sd <- preds_sd
validate_time_type(time_type)
private$time_type <- time_type
if (!(class(covariate_formula) == "formula")) {
stop("'covariate_formula' must be a formula")
}
validate_formula(covariate_formula)
private$covariate_formula <- covariate_formula
private$all_formula_vars <- all.vars(covariate_formula)
if (is.null(data) && is.null(file_path)) {
Expand All @@ -253,13 +247,17 @@ biokinetics <- R6::R6Class(
}
private$data <- data
}
unknown_vars <- private$all_formula_vars[which(!(private$all_formula_vars %in% names(private$data)))]
if (length(unknown_vars) > 0) {
stop(paste("All variables in 'covariate_formula' must correspond to data columns. Found unknown variables:",
paste(unknown_vars, collapse = ", ")))
}
validate_required_cols(private$data)
validate_formula_vars(private$all_formula_vars, private$data)
logger::log_info("Preparing data for stan")
private$data <- convert_log_scale(private$data, "titre")
private$data <- convert_log_scale(private$data, "value")
private$data[, `:=`(titre_type_num = as.numeric(as.factor(titre_type)),
obs_id = seq_len(.N))]
if (time_type == "relative") {
private$data[, t_since_last_exp := as.integer(date - last_exp_date, units = "days")]
} else {
private$data[, t_since_min_date := as.integer(date - min(date), units = "days")]
}
private$construct_design_matrix()
private$build_covariate_lookup_table()
private$prepare_stan_data()
Expand All @@ -269,6 +267,11 @@ biokinetics <- R6::R6Class(
package = "epikinetics"
)
},
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
get_stan_data = function() {
private$stan_input_data
},
#' @description Fit the model and return CmdStanMCMC fitted model object.
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html>
#' @param ... Named arguments to the `sample()` method of CmdStan model.
Expand Down Expand Up @@ -336,8 +339,7 @@ biokinetics <- R6::R6Class(
dt_out <- private$extract_parameters(params, n_draws)

data.table::setcolorder(dt_out, c("n", "k", ".draw"))

data.table::setnames(dt_out, c("n", ".draw"), c("stan_id", "draw"))
data.table::setnames(dt_out, c("n", ".draw"), c("pid", "draw"))

if (human_readable_covariates) {
logger::log_info("Recovering covariate names")
Expand Down Expand Up @@ -443,7 +445,7 @@ biokinetics <- R6::R6Class(
#' @description Simulate individual trajectories from the model. This is
#' computationally expensive and may take a while to run if n_draws is large.
#' @return A data.table. If summarise = TRUE columns are calendar_date, titre_type, me, lo, hi, time_shift.
#' If summarise = FALSE, columns are stan_id, draw, t, mu, titre_type, exposure_date, calendar_date, time_shift
#' If summarise = FALSE, columns are pid, draw, t, mu, titre_type, exposure_date, calendar_date, time_shift
#' and a column for each covariate in the hierarchical model. See the data vignette for details:
#' \code{vignette("data", package = "epikinetics")}.
#' @param summarise Boolean. If TRUE, average the individual trajectories to get lo, me and
Expand All @@ -468,18 +470,18 @@ biokinetics <- R6::R6Class(
# Calculating the maximum time each individual has data for after the
# exposure of interest
dt_max_dates <- private$data[
, .(t_max = max(t_since_last_exp)), by = .(stan_id)]
, .(t_max = max(t_since_last_exp)), by = .(pid)]

# A very small number of individuals have bleeds on the same day or a few days
# after their recorded exposure dates, resulting in very short trajectories.
# Adding a 50 day buffer to any individuals with less than or equal to 50 days
# of observations after their focal exposure
dt_max_dates <- dt_max_dates[t_max <= 50, t_max := 50, by = .(stan_id)]
dt_max_dates <- dt_max_dates[t_max <= 50, t_max := 50, by = .(pid)]

# Merging the parameter draws with the maximum time data.table
dt_params_ind <- merge(dt_params_ind, dt_max_dates, by = "stan_id")
dt_params_ind <- merge(dt_params_ind, dt_max_dates, by = "pid")

dt_params_ind_trim <- dt_params_ind[, .SD[draw %in% 1:n_draws], by = stan_id]
dt_params_ind_trim <- dt_params_ind[, .SD[draw %in% 1:n_draws], by = pid]

# Running the C++ code to simulate trajectories for each parameter sample
# for each individual
Expand All @@ -495,13 +497,13 @@ biokinetics <- R6::R6Class(
logger::log_info(paste("Calculating exposure dates. Adjusting exposures by", time_shift, "days"))
dt_lookup <- private$data[, .(
exposure_date = min(last_exp_date) - time_shift),
by = c(private$all_formula_vars, "stan_id")]
by = c(private$all_formula_vars, "pid")]

dt_out <- merge(dt_params_ind_traj, dt_lookup, by = "stan_id")
dt_out <- merge(dt_params_ind_traj, dt_lookup, by = "pid")

dt_out[
, calendar_date := exposure_date + t,
by = c(private$all_formula_vars, "stan_id", "titre_type")]
by = c(private$all_formula_vars, "pid", "titre_type")]

if (summarise) {
logger::log_info("Resampling")
Expand Down
29 changes: 29 additions & 0 deletions R/validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,32 @@ validate_time_type <- function(time_type) {
stop("'time_type' must be one of 'relative' or 'absolute'")
}
}

validate_formula <- function(covariate_formula) {
if (!(class(covariate_formula) == "formula")) {
stop("'covariate_formula' must be a formula")
}
}

validate_formula_vars <- function(formula_vars, data) {
unknown_vars <- formula_vars[which(!(formula_vars %in% names(data)))]
if (length(unknown_vars) > 0) {
stop(paste("All variables in 'covariate_formula' must correspond to data columns. Found unknown variables:",
paste(unknown_vars, collapse = ", ")))
}
}

validate_required_cols <- function(dat) {
required_cols <- c("pid", "date", "last_exp_date", "titre_type", "value", "censored")
missing_cols <- required_cols[!(required_cols %in% colnames(dat))]
if (length(missing_cols) > 0) {
stop(paste("Missing required columns:",
paste(missing_cols, collapse = ", ")))
}
}

validate_priors <- function(priors) {
if (!inherits(priors, "biokinetics_priors")) {
stop("'priors' must be of type 'biokinetics_priors'")
}
}
Loading

0 comments on commit 9a808d2

Please sign in to comment.