diff --git a/R/biokinetics.R b/R/biokinetics.R index f3234c2..a9e4a7c 100644 --- a/R/biokinetics.R +++ b/R/biokinetics.R @@ -63,41 +63,13 @@ biokinetics <- R6::R6Class( # Identify columns with no variance and remove them variance_per_column <- apply(mm, 2, var) relevant_columns <- which(variance_per_column != 0) - mm_reduced <- mm[, relevant_columns] + mm_reduced <- mm[, relevant_columns, drop = FALSE] private$design_matrix <- mm_reduced }, build_covariate_lookup_table = function() { - # Extract column names - col_names <- colnames(private$design_matrix) - - # Split column names based on the ':' delimiter - split_data <- stringr::str_split(col_names, ":", simplify = TRUE) - - # Convert the matrix to a data.table - dt <- data.table::as.data.table(split_data) - - # Set the new column names - data.table::setnames(dt, private$all_formula_vars) - - for (col_name in names(dt)) { - # Find the matching formula variable for current column - matching_formula_var <- private$all_formula_vars[which(startsWith(col_name, private$all_formula_vars))] - if (length(matching_formula_var) > 0) { - pattern_to_remove <- paste0("^", matching_formula_var) - dt[, (col_name) := stringr::str_remove_all(get(col_name), pattern_to_remove)] - } - } - - # Declare variables to suppress notes when compiling package - # https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 - p <- NULL - - # .I is a special symbol in data.table for row number - dt[, p := .I] - - # Reorder columns to have 'i' first - data.table::setcolorder(dt, "p") - private$covariate_lookup_table <- dt + private$covariate_lookup_table <- build_covariate_lookup_table(private$data, + private$design_matrix, + private$all_formula_vars) }, recover_covariate_names = function(dt) { # Declare variables to suppress notes when compiling package @@ -110,7 +82,7 @@ biokinetics <- R6::R6Class( dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)] if ("p" %in% colnames(dt)) { - dt_out <- 
dt_out[private$covariate_lookup_table, on = "p"][, `:=`(p = NULL)] + dt_out <- dt_out[private$covariate_lookup_table, on = "p", nomatch = NULL][, `:=`(p = NULL)] } dt_out }, @@ -118,27 +90,38 @@ biokinetics <- R6::R6Class( summarise, n_draws) { + has_covariates <- length(private$all_formula_vars) > 0 + # Declare variables to suppress notes when compiling package # https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 t0_pop <- tp_pop <- ts_pop <- m1_pop <- m2_pop <- m3_pop <- NULL - beta_t0 <- beta_tp <- beta_ts <- beta_m1 <- beta_m2 <- beta_m3 <- NULL k <- p <- .draw <- t_id <- mu <- NULL + params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", + "m1_pop[k]", "m2_pop[k]", "m3_pop[k]") + if (has_covariates) { + params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", + "beta_m1[p]", "beta_m2[p]", "beta_m3[p]") + } + + params_proc <- rlang::parse_exprs(params) + dt_samples_wide <- tidybayes::spread_draws( - private$fitted, - t0_pop[k], tp_pop[k], ts_pop[k], - m1_pop[k], m2_pop[k], m3_pop[k], - beta_t0[p], beta_tp[p], beta_ts[p], - beta_m1[p], beta_m2[p], beta_m3[p]) |> + private$fitted, !!!params_proc) |> data.table() dt_samples_wide <- dt_samples_wide[.draw %in% 1:n_draws] - dt_samples_wide[, `:=`(.chain = NULL, .iteration = NULL)] + if (!has_covariates) { + # there are no covariates, so add dummy column + # that will be removed after processing + dt_samples_wide$p <- 1 + } + data.table::setcolorder(dt_samples_wide, c("k", "p", ".draw")) - if (length(private$all_formula_vars) > 0) { + if (has_covariates) { logger::log_info("Adjusting by regression coefficients") dt_samples_wide <- private$adjust_parameters(dt_samples_wide) } @@ -165,6 +148,11 @@ biokinetics <- R6::R6Class( } data.table::setcolorder(dt_out, c("t", "p", "k")) + + if (!has_covariates) { + dt_out[, p:= NULL] + } + dt_out }, prepare_stan_data = function() { stan_id <- titre <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- t_since_min_date <- NULL @@ 
-228,7 +216,8 @@ biokinetics <- R6::R6Class( #' for required columns: \code{vignette("data", package = "epikinetics")}. #' @param file_path Optional file path to model inputs in CSV format. One of data or file must be provided. #' @param priors Object of type \link[epikinetics]{biokinetics_priors}. Default biokinetics_priors(). - #' @param covariate_formula Formula specifying hierarchical structure of model. Default ~0. + #' @param covariate_formula Formula specifying linear regression model. Note all variables in the formula + #' will be treated as categorical variables. Default ~0. #' @param preds_sd Standard deviation of predictor coefficients. Default 0.25. #' @param time_type One of 'relative' or 'absolute'. Default 'relative'. initialize = function(priors = biokinetics_priors(), @@ -296,14 +285,23 @@ biokinetics <- R6::R6Class( extract_population_parameters = function(n_draws = 2500, human_readable_covariates = TRUE) { private$check_fitted() - params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", - "m3_pop[k]", "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", - "beta_m2[p]", "beta_m3[p]") + has_covariates <- length(private$all_formula_vars) > 0 + + params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", "m3_pop[k]") + + if (has_covariates) { + params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", "beta_m2[p]", "beta_m3[p]") + } logger::log_info("Extracting parameters") dt_out <- private$extract_parameters(params, n_draws) - data.table::setcolorder(dt_out, c("k", "p", ".draw")) + if (has_covariates){ + data.table::setcolorder(dt_out, c("p", "k", ".draw")) + } else { + data.table::setcolorder(dt_out, c("k", ".draw")) + } + data.table::setnames(dt_out, ".draw", "draw") if (length(private$all_formula_vars) > 0) { @@ -338,6 +336,7 @@ biokinetics <- R6::R6Class( dt_out <- private$extract_parameters(params, n_draws) data.table::setcolorder(dt_out, c("n", "k", ".draw")) + 
data.table::setnames(dt_out, c("n", ".draw"), c("stan_id", "draw")) if (human_readable_covariates) { @@ -408,6 +407,12 @@ biokinetics <- R6::R6Class( human_readable_covariates = FALSE) logger::log_info("Calculating peak and switch titre values") + + by <- c("k", "draw") + if ("p" %in% colnames(dt_peak_switch)) { + by <- c("p", by) + } + dt_peak_switch[, `:=`( mu_0 = biokinetics_simulate_trajectory( 0, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), @@ -415,7 +420,7 @@ biokinetics <- R6::R6Class( tp_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), mu_s = biokinetics_simulate_trajectory( ts_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop)), - by = c("p", "k", "draw")] + by = by] logger::log_info("Recovering covariate names") dt_peak_switch <- private$recover_covariate_names(dt_peak_switch) @@ -482,7 +487,7 @@ biokinetics <- R6::R6Class( dt_params_ind_traj <- biokinetics_simulate_trajectories(dt_params_ind_trim) dt_params_ind_traj <- data.table::setDT(convert_log_scale_inverse_cpp( - dt_params_ind_traj, vars_to_transform = "mu")) + dt_params_ind_traj, vars_to_transform = "mu")) logger::log_info("Recovering covariate names") dt_params_ind_traj <- private$recover_covariate_names(dt_params_ind_traj) @@ -499,18 +504,19 @@ biokinetics <- R6::R6Class( by = c(private$all_formula_vars, "stan_id", "titre_type")] if (summarise) { - logger::log_info("Summarising into population quantiles") + logger::log_info("Resampling") dt_out <- dt_out[ !is.nan(mu), .(pop_mu_sum = mean(mosaic::resample(mu))), by = c("calendar_date", "draw", "titre_type")] + logger::log_info("Summarising into population quantiles") dt_out <- summarise_draws( dt_out, column_name = "pop_mu_sum", by = c("calendar_date", "titre_type")) } - dt_out[, time_shift:= time_shift] + dt_out[, time_shift := time_shift] } ) ) diff --git a/R/utils.R b/R/utils.R index 09981d5..9ccd3fa 100644 --- a/R/utils.R +++ b/R/utils.R @@ -3,11 +3,11 @@ convert_log_scale <- function( simplify_limits = TRUE) { dt_out <- 
data.table::copy(dt_in) - for(var in vars_to_transform) { - if(simplify_limits == TRUE) { + for (var in vars_to_transform) { + if (simplify_limits == TRUE) { dt_out[get(var) > 2560, (var) := 2560] } - dt_out[, (var) := log2(get(var)/5)] + dt_out[, (var) := log2(get(var) / 5)] } return(dt_out) } @@ -23,9 +23,9 @@ convert_log_scale <- function( #' @export convert_log_scale_inverse <- function(dt_in, vars_to_transform) { dt_out <- data.table::copy(dt_in) - for(var in vars_to_transform) { + for (var in vars_to_transform) { # Reverse the log2 transformation and multiplication by 5. - dt_out[, (var) := 5*2^(get(var))] + dt_out[, (var) := 5 * 2^(get(var))] } dt_out } @@ -40,6 +40,33 @@ summarise_draws <- function(dt_in, column_name, by = by) { lo = stats::quantile(get(column_name), 0.025), hi = stats::quantile(get(column_name), 0.975) ), - by = by + by = by ] } + +build_covariate_lookup_table <- function(data, design_matrix, all_formula_vars) { + + if (length(all_formula_vars) == 0) { + return(NULL) + } + p_names <- colnames(design_matrix) + p_names <- data.table::data.table(p_name = p_names, p = seq_along(p_names)) + + factors <- lapply(all_formula_vars, function(v) { levels(as.factor(as.data.frame(data)[, v])) }) + names(factors) <- all_formula_vars + combinations <- expand.grid(factors) + + for (i in 1:nrow(combinations)) { + combinations[i, "p_name"] <- paste0(all_formula_vars, as.matrix(combinations[i, all_formula_vars]), collapse = ":") + } + + dt_out <- data.table::setDT(combinations)[p_names, on = "p_name"] + + for (f in names(factors)) { + re <- paste0("(?<=", f, ").+?(?=$|:)") + dt_out[, f] <- dt_out[, stringr::str_extract(p_name, re)] + } + + dt_out[, p_name := NULL] + dt_out +} diff --git a/man/biokinetics.Rd b/man/biokinetics.Rd index f89bd87..12cc552 100644 --- a/man/biokinetics.Rd +++ b/man/biokinetics.Rd @@ -45,7 +45,8 @@ for required columns: \code{vignette("data", package = "epikinetics")}.} \item{\code{file_path}}{Optional file path to model inputs in 
CSV format. One of data or file must be provided.} -\item{\code{covariate_formula}}{Formula specifying hierarchical structure of model. Default ~0.} +\item{\code{covariate_formula}}{Formula specifying linear regression model. Note all variables in the formula +will be treated as categorical variables. Default ~0.} \item{\code{preds_sd}}{Standard deviation of predictor coefficients. Default 0.25.} diff --git a/tests/testthat/manual-test-multiplecovariates.R b/tests/testthat/manual-test-multiplecovariates.R new file mode 100644 index 0000000..2246e40 --- /dev/null +++ b/tests/testthat/manual-test-multiplecovariates.R @@ -0,0 +1,118 @@ +library(ggplot2) +library(epikinetics) + +mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + priors = scova_priors(), + covariate_formula = ~0 + infection_history:last_vax_type) + +mod$fit(chains = 4, + parallel_chains = 4, + iter_warmup = 50, + iter_sampling = 200, + threads_per_chain = 4) + +dat <- mod$simulate_population_trajectories() +dat[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot(data = dat[!is.na(last_vax_type)]) + + geom_line(aes(x = t, + y = me, + colour = titre_type)) + + geom_ribbon(aes(x = t, + ymin = lo, + ymax = hi, + fill = titre_type), alpha = 0.65) + + coord_cartesian(clip = "off") + + labs(x = "Time since last exposure (days)", + y = expression(paste("Titre (IC"[50], ")"))) + + scale_y_continuous( + trans = "log2") + + facet_wrap(titre_type ~ infection_history + last_vax_type) + +dat <- mod$population_stationary_points() +dat[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot(data = dat[!is.na(last_vax_type)], aes( + x = mu_p, y = mu_s, + colour = titre_type)) + + geom_density_2d( + aes( + group = interaction( + infection_history, + last_vax_type, + titre_type))) + + geom_point(alpha = 0.05, size = 0.2) + + geom_point(aes(x = mu_p_me, y = mu_s_me, + shape = 
interaction(infection_history, last_vax_type)), + colour = "black") + + geom_path(aes(x = mu_p_me, y = mu_s_me, + group = titre_type), + colour = "black") + + geom_vline(xintercept = 2560, linetype = "twodash", colour = "gray30") + + scale_x_continuous( + trans = "log2", + breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), + labels = c(expression(" " <= 40), + "80", "160", "320", "640", "1280", "2560", "5120", "10240"), + limits = c(NA, 10240)) + + geom_hline(yintercept = 2560, linetype = "twodash", colour = "gray30") + + scale_y_continuous( + trans = "log2", + breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), + labels = c(expression(" " <= 40), + "80", "160", "320", "640", "1280", "2560", "5120", "10240"), + limits = c(NA, 5120)) + + labs(x = expression(paste("Population-level titre value at peak (IC"[50], ")")), + y = expression(paste("Population-level titre value at set-point (IC"[50], ")"))) + +dat <- mod$simulate_individual_trajectories() +rawdat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) + +date_delta <- lubridate::ymd("2021-05-07") +date_ba2 <- lubridate::ymd("2022-01-24") + +dat$wave <- "Delta" +rawdat$wave <- "Delta" +plot_data <- merge( + dat, rawdat[, .( + min_date = min(date), max_date = max(date)), by = wave])[ + , .SD[calendar_date >= min_date & calendar_date <= date_ba2], by = wave] + +plot_data[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot() + geom_line( + data = plot_data, + aes(x = calendar_date, + y = me, + group = interaction(titre_type, wave), + colour = titre_type), + alpha = 0.2) + + geom_ribbon( + data = plot_data, + aes(x = calendar_date, + ymin = lo, + ymax = hi, + group = interaction(titre_type, wave) + ), + alpha = 0.2) + + labs(title = "Population-level titre values", + tag = "A", + x = "Date", + y = expression(paste("Titre (IC"[50], ")"))) + + scale_x_date( + date_labels = "%b %Y", + limits = c(min(rawdat$date), date_ba2)) + + 
geom_smooth( + data = plot_data, + aes(x = calendar_date, + y = me, + fill = titre_type, + colour = titre_type, + group = interaction(titre_type, wave)), + alpha = 0.5, span = 0.2) \ No newline at end of file diff --git a/tests/testthat/manual-test-nocovariates.R b/tests/testthat/manual-test-nocovariates.R new file mode 100644 index 0000000..81d0a20 --- /dev/null +++ b/tests/testthat/manual-test-nocovariates.R @@ -0,0 +1,117 @@ +library(ggplot2) +library(epikinetics) + +mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + priors = scova_priors()) + +mod$fit(chains = 4, + parallel_chains = 4, + iter_warmup = 50, + iter_sampling = 200, + threads_per_chain = 4) + +dat <- mod$simulate_population_trajectories() +dat[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot(data = dat) + + geom_line(aes(x = t, + y = me, + colour = titre_type)) + + geom_ribbon(aes(x = t, + ymin = lo, + ymax = hi, + fill = titre_type), alpha = 0.65) + + coord_cartesian(clip = "off") + + labs(x = "Time since last exposure (days)", + y = expression(paste("Titre (IC"[50], ")"))) + + scale_y_continuous( + trans = "log2") + + facet_wrap(~titre_type) + +dat <- mod$population_stationary_points() +dat[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot(data = dat, aes( + x = mu_p, y = mu_s, + colour = titre_type)) + + geom_density_2d( + aes( + group = titre_type)) + + geom_point(alpha = 0.05, size = 0.2) + + geom_point(aes(x = mu_p_me, y = mu_s_me), + colour = "black") + + geom_path(aes(x = mu_p_me, y = mu_s_me, + group = titre_type), + colour = "black") + + geom_vline(xintercept = 2560, linetype = "twodash", colour = "gray30") + + scale_x_continuous( + trans = "log2", + breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), + labels = c(expression(" " <= 40), + "80", "160", "320", "640", "1280", "2560", "5120", "10240"), + limits = c(NA, 10240)) + + geom_hline(yintercept 
= 2560, linetype = "twodash", colour = "gray30") + + scale_y_continuous( + trans = "log2", + breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), + labels = c(expression(" " <= 40), + "80", "160", "320", "640", "1280", "2560", "5120", "10240"), + limits = c(NA, 5120)) + + scale_shape_manual(values = c(1, 2, 3)) + + labs(x = expression(paste("Population-level titre value at peak (IC"[50], ")")), + y = expression(paste("Population-level titre value at set-point (IC"[50], ")"))) + + guides(colour = guide_legend(title = "Titre type", override.aes = list(alpha = 1, size = 1)), + shape = guide_legend(title = "Infection history")) + +dat <- mod$simulate_individual_trajectories() + +rawdat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) + +date_delta <- lubridate::ymd("2021-05-07") +date_ba2 <- lubridate::ymd("2022-01-24") + +dat$wave <- "Delta" +rawdat$wave <- "Delta" +plot_data <- merge( + dat, rawdat[, .( + min_date = min(date), max_date = max(date)), by = wave])[ + , .SD[calendar_date >= min_date & calendar_date <= date_ba2], by = wave] + +plot_data[, titre_type := forcats::fct_relevel( + titre_type, + c("Ancestral", "Alpha", "Delta"))] + +ggplot() + geom_line( + data = plot_data, + aes(x = calendar_date, + y = me, + group = interaction(titre_type, wave), + colour = titre_type), + alpha = 0.2) + + geom_ribbon( + data = plot_data, + aes(x = calendar_date, + ymin = lo, + ymax = hi, + group = interaction(titre_type, wave) + ), + alpha = 0.2) + + labs(title = "Population-level titre values", + tag = "A", + x = "Date", + y = expression(paste("Titre (IC"[50], ")"))) + + scale_x_date( + date_labels = "%b %Y", + limits = c(min(rawdat$date), date_ba2)) + + geom_smooth( + data = plot_data, + aes(x = calendar_date, + y = me, + fill = titre_type, + colour = titre_type, + group = interaction(titre_type, wave)), + alpha = 0.5, span = 0.2) diff --git a/tests/testthat/test-extract-parameters.R b/tests/testthat/test-extract-parameters.R index 
de0e955..efbc8a1 100644 --- a/tests/testthat/test-extract-parameters.R +++ b/tests/testthat/test-extract-parameters.R @@ -2,32 +2,42 @@ mock_model <- function(name, package) { list(sample = function(x, ...) readRDS(test_path("testdata", "testdraws.rds"))) } -local_mocked_bindings( - stan_package_model = mock_model, .package = "instantiate" -) - test_that("Cannot retrieve population params until model is fitted", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics")) expect_error(mod$extract_population_parameters(), "Model has not been fitted yet. Call 'fit' before calling this function.") }) test_that("Cannot retrieve individual params until model is fitted", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics")) expect_error(mod$extract_individual_parameters(), "Model has not been fitted yet. 
Call 'fit' before calling this function.") }) test_that("Can extract population parameters without human readable covariates", { + + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) + mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), - covariate_formula = ~0 + infection_history) + covariate_formula = ~0 + infection_history) mod$fit() params <- mod$extract_population_parameters(n_draws = 10, human_readable_covariates = FALSE) - expect_equal(names(params), c("k", "p", "draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", + expect_equal(names(params), c("p", "k", "draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2", "beta_m3")) }) test_that("Can extract population parameters with human readable covariates", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), - covariate_formula = ~0 + infection_history) + covariate_formula = ~0 + infection_history) mod$fit() params <- mod$extract_population_parameters(n_draws = 10, human_readable_covariates = TRUE) expect_equal(names(params), c("draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", @@ -36,8 +46,11 @@ test_that("Can extract population parameters with human readable covariates", { }) test_that("Can extract individual parameters without human readable covariates", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), - covariate_formula = ~0 + infection_history) + covariate_formula = ~0 + infection_history) mod$fit() params <- mod$extract_individual_parameters(n_draws = 10, human_readable_covariates = FALSE, @@ -47,8 +60,11 @@ test_that("Can extract individual parameters without human readable 
covariates", }) test_that("Can extract individual parameters with human readable covariates", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), - covariate_formula = ~0 + infection_history) + covariate_formula = ~0 + infection_history) mod$fit() params <- mod$extract_individual_parameters(n_draws = 10, human_readable_covariates = TRUE, @@ -58,8 +74,12 @@ test_that("Can extract individual parameters with human readable covariates", { }) test_that("Can extract individual parameters with variation params", { + local_mocked_bindings( + stan_package_model = mock_model, .package = "instantiate" + ) + mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), - covariate_formula = ~0 + infection_history) + covariate_formula = ~0 + infection_history) mod$fit() params <- mod$extract_individual_parameters(n_draws = 10, human_readable_covariates = TRUE, diff --git a/tests/testthat/test-recover-covariates.R b/tests/testthat/test-recover-covariates.R new file mode 100644 index 0000000..e470be8 --- /dev/null +++ b/tests/testthat/test-recover-covariates.R @@ -0,0 +1,60 @@ +test_that("Can recover single covariate", { + dat <- data.table(infection_history = c("Infection naive", "Previously infected"), + value = 1, + last_vax_type = c("a", "b", "c", "d")) + mm <- stats::model.matrix(~0 + infection_history, dat) + lookup <- build_covariate_lookup_table(dat, mm, c("infection_history")) + + values <- data.table(p = c(1, 2, 2, 1), value = c(1, 2, 3, 4)) + res <- na.omit(values[lookup, on = "p"], colnames(values))[, `:=`(p = NULL)] + expect_equal(names(res), c("value", "infection_history")) + expect_equal(res$infection_history, c("Infection naive", "Infection naive", + "Previously infected", "Previously infected")) + expect_equal(res$value, c(1, 4, 2, 3)) +}) + +test_that("Can recover multiple covariate from interaction term", 
{ + dat <- data.frame(infection_history = c("Infection naive", "Previously infected"), + value = 1, + last_vax_type = c("a", "b", "c", "d")) + mm <- stats::model.matrix(~0 + infection_history:last_vax_type, dat) + lookup <- build_covariate_lookup_table(dat, mm, c("infection_history", "last_vax_type")) + + values <- data.table(p = c(1, 2, 3, 4, 5, 6, 7, 8), value = c(1, 2, 3, 4, 5, 6, 7, 8)) + res <- na.omit(values[lookup, on = "p"], colnames(values))[, `:=`(p = NULL)] + expect_equal(names(res), c("value", "infection_history", "last_vax_type")) + expect_equal(res$infection_history, rep(c("Infection naive", "Previously infected"), 4)) + expect_equal(res$last_vax_type, c("a", "a", "b", "b", "c", "c", "d", "d")) + expect_equal(res$value, c(1, 2, 3, 4, 5, 6, 7, 8)) +}) + +test_that("Can recover multiple covariates without interaction term", { + dat <- data.frame(infection_history = c("Infection naive", "Previously infected"), + value = 1, + last_vax_type = c("a", "b", "c", "d")) + mm <- stats::model.matrix(~0 + infection_history + last_vax_type, dat) + lookup <- build_covariate_lookup_table(dat, mm, c("infection_history", "last_vax_type")) + + values <- data.table(p = c(1, 2, 3, 4, 5), value = c(1, 2, 3, 4, 5)) + res <- na.omit(values[lookup, on = "p"], colnames(values))[, `:=`(p = NULL)] + expect_equal(names(res), c("value", "infection_history", "last_vax_type")) + expect_equal(res$infection_history, c("Infection naive", "Previously infected", NA, NA, NA)) + expect_equal(res$last_vax_type, c(NA, NA, "b", "c", "d")) + expect_equal(res$value, c(1, 2, 3, 4, 5)) +}) + +test_that("Can recover multiple covariates and interaction term", { + dat <- data.frame(infection_history = c("Infection naive", "Previously infected"), + value = 1, + last_vax_type = c("a", "b", "c", "d")) + mm <- stats::model.matrix(~0 + infection_history + last_vax_type + infection_history:last_vax_type, dat) + lookup <- build_covariate_lookup_table(dat, mm, c("infection_history", "last_vax_type")) + + 
values <- data.table(p = c(1, 2, 3, 4, 5, 6, 7, 8), value = c(1, 2, 3, 4, 5, 6, 7, 8)) + res <- na.omit(values[lookup, on = "p"], colnames(values))[, `:=`(p = NULL)] + expect_equal(names(res), c("value", "infection_history", "last_vax_type")) + expect_equal(res$infection_history, c("Infection naive", "Previously infected", NA, NA, NA, + "Previously infected", "Previously infected", "Previously infected")) + expect_equal(res$last_vax_type, c(NA, NA, "b", "c", "d", "b", "c", "d")) + expect_equal(res$value, c(1, 2, 3, 4, 5, 6, 7, 8)) +}) diff --git a/tests/testthat/test-run-model.R b/tests/testthat/test-run-model.R index 545a746..ba337ea 100644 --- a/tests/testthat/test-run-model.R +++ b/tests/testthat/test-run-model.R @@ -1,12 +1,19 @@ -mock_model <- function(name, package) { +mock_model_return_args <- function(name, package) { list(sample = function(x, ...) list(...)) } -local_mocked_bindings( - stan_package_model = mock_model, .package = "instantiate" -) +mock_model_multiple_covariates <- function(name, package) { + list(sample = function(x, ...) readRDS(test_path("testdata", "testdraws_multiplecovariates.rds"))) +} + +mock_model_no_covariates <- function(name, package) { + list(sample = function(x, ...) 
readRDS(test_path("testdata", "testdraws_nocovariates.rds"))) +} test_that("Can fit model with arguments", { + local_mocked_bindings( + stan_package_model = mock_model_return_args, .package = "instantiate" + ) res <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), priors = biokinetics_priors())$fit(chains = 4, parallel_chains = 4, @@ -15,3 +22,57 @@ test_that("Can fit model with arguments", { threads_per_chain = 4) expect_equal(names(res), c("chains", "parallel_chains", "iter_warmup", "iter_sampling", "threads_per_chain")) }) + +test_that("Can process model fits with no covariates", { + local_mocked_bindings( + stan_package_model = mock_model_no_covariates, .package = "instantiate" + ) + mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics")) + + res <- mod$fit(chains = 4, + parallel_chains = 4, + iter_warmup = 10, + iter_sampling = 40, + threads_per_chain = 4) + + pt <- mod$simulate_population_trajectories(n_draws = 100, summarise = FALSE) + expect_equal(names(pt), c("t", ".draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", + "m3_pop", "mu", "titre_type")) + + pt <- mod$simulate_population_trajectories(n_draws = 100, summarise = TRUE) + expect_equal(names(pt), c("t", "me", "lo", "hi", "titre_type")) + + it <- mod$simulate_individual_trajectories(n_draws = 100, summarise = FALSE) + expect_equal(names(it), c("stan_id", "draw", "t", "mu", "titre_type", + "exposure_date", "calendar_date", "time_shift")) + + # it <- mod$simulate_individual_trajectories(n_draws = 100, summarise = TRUE) + # expect_equal(names(it), c("calendar_date", "titre_type", "me", "lo", "hi", "time_shift")) + + sp <- mod$population_stationary_points() + expect_equal(names(sp), c("titre_type", "mu_p", "mu_s", "rel_drop_me", "mu_p_me", "mu_s_me")) +}) + +test_that("Can process model fits with multiple covariates", { + local_mocked_bindings( + stan_package_model = mock_model_multiple_covariates, .package = "instantiate" + ) + 
mod <- biokinetics$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + covariate_formula = ~0 + infection_history + last_vax_type) + + res <- mod$fit(chains = 4, + parallel_chains = 4, + iter_warmup = 10, + iter_sampling = 40, + threads_per_chain = 4) + + pt <- mod$simulate_population_trajectories(n_draws = 100) + expect_equal(names(pt), c("t", "me", "lo", "hi", "titre_type", "infection_history", "last_vax_type")) + + it <- mod$simulate_individual_trajectories(n_draws = 100) + expect_equal(names(it), c("calendar_date", "titre_type", "me", "lo", "hi", "time_shift")) + + sp <- mod$population_stationary_points() + expect_equal(names(sp), c("infection_history", "last_vax_type", "titre_type", + "mu_p", "mu_s", "rel_drop_me", "mu_p_me", "mu_s_me")) +}) diff --git a/tests/testthat/testdata/testdraws_multiplecovariates.rds b/tests/testthat/testdata/testdraws_multiplecovariates.rds new file mode 100644 index 0000000..e45e5f7 Binary files /dev/null and b/tests/testthat/testdata/testdraws_multiplecovariates.rds differ diff --git a/tests/testthat/testdata/testdraws_nocovariates.rds b/tests/testthat/testdata/testdraws_nocovariates.rds new file mode 100644 index 0000000..0df59aa Binary files /dev/null and b/tests/testthat/testdata/testdraws_nocovariates.rds differ diff --git a/vignettes/biokinetics.Rmd b/vignettes/biokinetics.Rmd index 88e7d42..290e115 100644 --- a/vignettes/biokinetics.Rmd +++ b/vignettes/biokinetics.Rmd @@ -21,7 +21,7 @@ This vignette demonstrates how to use the `epikinetics` package to replicate som # Fitting the model -To initialise a model object, the only required argument is a path to the data in CSV format, or a [data.table]("https://cran.r-project.org/web/packages/data.table/index.html"). See [biokinetics](../reference/biokinetics.html) for all available arguments. 
In this vignette we use a dataset representing the Delta wave which is installed with this package, specifying a hierarchical model that just looks at the effect of infection history. +To initialise a model object, the only required argument is a path to the data in CSV format, or a [data.table](https://cran.r-project.org/web/packages/data.table/index.html). See [biokinetics](../reference/biokinetics.html) for all available arguments. In this vignette we use a dataset representing the Delta wave which is installed with this package, specifying a regression model that just looks at the effect of infection history. The `fit` method then has the same function signature as the underlying [cmdstanr::sample](https://mc-stan.org/cmdstanr/reference/model-method-sample.html) method. Here we specify a relatively small number of iterations of the algorithm to limit the time it takes to compile this vignette. diff --git a/vignettes/data.Rmd b/vignettes/data.Rmd index af88f7f..92c2924 100644 --- a/vignettes/data.Rmd +++ b/vignettes/data.Rmd @@ -31,7 +31,7 @@ censored: It can also contain further columns for any covariates to be included in the model. The data files installed with this package have additional columns infection_history, last_vax_type, and exp_num. -The model also accepts a covariate formula for a hierarchical model. The variables in the formula must correspond to column names in the dataset. +The model also accepts a covariate formula to define the regression model. The variables in the formula must correspond to column names in the dataset. Note that all variables will be treated as **categorical variables**; that is, converted to factors regardless of their input type. # Ouput data