Skip to content

Commit

Permalink
rename args, fix warnings, test behaviour of strict limits
Browse files Browse the repository at this point in the history
  • Loading branch information
hillalex committed Nov 19, 2024
1 parent 1004715 commit ccaeb9c
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 134 deletions.
134 changes: 71 additions & 63 deletions R/biokinetics.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ biokinetics <- R6::R6Class(
scale = NULL,
priors = NULL,
preds_sd = NULL,
upper_detection_limit = NULL,
lower_detection_limit = NULL,
upper_censoring_limit = NULL,
lower_censoring_limit = NULL,
smallest_value = NULL,
data = NULL,
covariate_formula = NULL,
Expand All @@ -29,28 +29,18 @@ biokinetics <- R6::R6Class(
stop("Model has not been fitted yet. Call 'fit' before calling this function.")
}
},
get_upper_detection_limit = function(upper_detection_limit) {
if (is.null(upper_detection_limit)) {
private$upper_detection_limit <- max(private$data$value)
get_upper_censoring_limit = function(upper_censoring_limit) {
if (is.null(upper_censoring_limit)) {
private$upper_censoring_limit <- max(private$data$value)
} else {
max_value <- max(private$data$value)
if (max_value >= upper_detection_limit) {
warning(sprintf("Data contains values >= the upper detection limit %s. These will be censored.",
upper_detection_limit))
}
private$upper_detection_limit <- upper_detection_limit
private$upper_censoring_limit <- upper_censoring_limit
}
},
get_lower_detection_limit = function(lower_detection_limit) {
if (is.null(lower_detection_limit)) {
private$lower_detection_limit <- min(private$data$value)
get_lower_censoring_limit = function(lower_censoring_limit) {
if (is.null(lower_censoring_limit)) {
private$lower_censoring_limit <- min(private$data$value)
} else {
min_value <- min(private$data$value)
if (min_value <= lower_detection_limit) {
warning(sprintf("Data contains values <= the lower detection limit %s. These will be censored.",
lower_detection_limit))
}
private$lower_detection_limit <- lower_detection_limit
private$lower_censoring_limit <- lower_censoring_limit
}
},
model_matrix_with_dummy = function(data) {
Expand Down Expand Up @@ -204,11 +194,11 @@ biokinetics <- R6::R6Class(
stan_data$P <- ncol(private$design_matrix)
if (private$scale == "natural") {
# do the same transformation as used on the data
stan_data$upper_detection_limit <- log2(private$upper_detection_limit / private$smallest_value)
stan_data$lower_detection_limit <- log2(private$lower_detection_limit / private$smallest_value)
stan_data$upper_censoring_limit <- log2(private$upper_censoring_limit / private$smallest_value)
stan_data$lower_censoring_limit <- log2(private$lower_censoring_limit / private$smallest_value)
} else {
stan_data$upper_detection_limit <- private$upper_detection_limit
stan_data$lower_detection_limit <- private$lower_detection_limit
stan_data$upper_censoring_limit <- private$upper_censoring_limit
stan_data$lower_censoring_limit <- private$lower_censoring_limit
}
private$stan_input_data <- c(stan_data, private$priors)
},
Expand Down Expand Up @@ -248,24 +238,24 @@ biokinetics <- R6::R6Class(
#' @param preds_sd Standard deviation of predictor coefficients. Default 0.25.
#' @param scale One of "log" or "natural". Default "natural". Is provided data on a log or a natural scale? If on a natural scale it
#' will be converted to a log scale for model fitting.
#' @param upper_detection_limit Optional upper detection limit of the titre used. This is needed to construct a likelihood for upper censored
#' values, so only needs to be provided if you have such values in the dataset. If not provided, the model will default
#' to using the largest value in the dataset as the upper detection limit.
#' @param lower_detection_limit Optional lower detection limit of the titre used. This is needed to construct a likelihood for lower censored
#' values, so only needs to be provided if you have such values in the dataset. If not provided, the model will default
#' to using the smallest value in the dataset as the lower detection limit.
#' @param truncate_upper Logical. Whether values greater than the upper detection limit should be truncated, i.e. set to the same value as the limit. Default TRUE.
#' @param truncate_lower Logical. Whether values smaller than the lower detection limit should be truncated, i.e. set to the same value as the limit. Default TRUE.
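#' @param upper_censoring_limit Optional value at which to upper censor data points. This is needed to construct a likelihood for upper censored
#' values, so only needs to be provided if you have such values in the dataset. If not provided, the largest value in the dataset is used as
#' the limit, so no values lie above it. Note that observations equal to the limit are still flagged as upper censored.
#' @param lower_censoring_limit Optional value at which to lower censor data points. This is needed to construct a likelihood for lower censored
#' values, so only needs to be provided if you have such values in the dataset. If not provided, the smallest value in the dataset is used as
#' the limit, so no values lie below it. Note that observations equal to the limit are still flagged as lower censored.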
#' @param strict_upper_limit Logical. Whether values greater than the upper censoring limit should be censored.
#' If FALSE, only values exactly equal to the upper censoring limit will be censored. Default TRUE.
#' @param strict_lower_limit Logical. Whether values smaller than the lower censoring limit should be censored.
#' If FALSE, only values exactly equal to the lower censoring limit will be censored. Default TRUE.
initialize = function(priors = biokinetics_priors(),
data = NULL,
file_path = NULL,
covariate_formula = ~0,
preds_sd = 0.25,
scale = "natural",
upper_detection_limit = NULL,
lower_detection_limit = NULL,
truncate_upper = TRUE,
truncate_lower = TRUE) {
upper_censoring_limit = NULL,
lower_censoring_limit = NULL,
strict_upper_limit = TRUE,
strict_lower_limit = TRUE) {
validate_priors(priors)
private$priors <- priors
validate_numeric(preds_sd)
Expand All @@ -291,18 +281,36 @@ biokinetics <- R6::R6Class(
validate_required_cols(private$data)
validate_formula_vars(private$all_formula_vars, private$data)
logger::log_info("Preparing data for stan")
private$get_upper_detection_limit(upper_detection_limit)
private$get_lower_detection_limit(lower_detection_limit)
if (truncate_upper) {
private$data[, value := ifelse(value > private$upper_detection_limit, private$upper_detection_limit, value)]
}
if (truncate_lower) {
private$data[, value := ifelse(value < private$lower_detection_limit, private$lower_detection_limit, value)]
private$get_upper_censoring_limit(upper_censoring_limit)
private$get_lower_censoring_limit(lower_censoring_limit)
max_value <- max(private$data$value)
min_value <- min(private$data$value)
values_above <- max_value > private$upper_censoring_limit
values_below <- min_value < private$lower_censoring_limit
if (strict_upper_limit) {
if (values_above) {
warning(sprintf("Data contains values above the upper censoring limit %s and these will be censored. To turn off this behaviour set strict_upper_limit to FALSE.",
private$upper_censoring_limit))
}
private$data[, value := ifelse(value > private$upper_censoring_limit, private$upper_censoring_limit, value)]
} else if (values_above) {
warning(sprintf("Data contains values above the upper censoring limit %s. To treat these as censored set strict_upper_limit to TRUE.",
private$upper_censoring_limit))
}
if (strict_lower_limit) {
if (values_below) {
warning(sprintf("Data contains values below the lower censoring limit %s and these will be censored. To turn off this behaviour set strict_lower_limit to FALSE.",
private$lower_censoring_limit))
}
private$data[, value := ifelse(value < private$lower_censoring_limit, private$lower_censoring_limit, value)]
} else if (values_below) {
warning(sprintf("Data contains values below the lower censoring limit %s. To treat these as censored set strict_lower_limit to TRUE.",
private$lower_censoring_limit))
}
private$data[, `:=`(obs_id = seq_len(.N),
time_since_last_exp = as.integer(day - last_exp_day, units = "days"),
censored_lo = value == private$lower_detection_limit,
censored_hi = value == private$upper_detection_limit)]
censored_lo = value == private$lower_censoring_limit,
censored_hi = value == private$upper_censoring_limit)]
private$construct_design_matrix()
private$build_covariate_lookup_table()
private$build_pid_lookup()
Expand Down Expand Up @@ -332,8 +340,8 @@ biokinetics <- R6::R6Class(
tmax = tmax,
n_draws = n_draws,
data = private$data,
upper_detection_limit = private$stan_input_data$upper_detection_limit,
lower_detection_limit = private$stan_input_data$lower_detection_limit)
upper_censoring_limit = private$stan_input_data$upper_censoring_limit,
lower_censoring_limit = private$stan_input_data$lower_censoring_limit)
},
#' @description Plot model input data with a smoothing function. Note that
#' this plot is of the data as provided to the Stan model so is on a log scale,
Expand All @@ -344,32 +352,32 @@ biokinetics <- R6::R6Class(
plot_sero_data(private$data,
tmax = tmax,
covariates = private$all_formula_vars,
upper_detection_limit = private$stan_input_data$upper_detection_limit,
lower_detection_limit = private$stan_input_data$lower_detection_limit)
upper_censoring_limit = private$stan_input_data$upper_censoring_limit,
lower_censoring_limit = private$stan_input_data$lower_censoring_limit)
},
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
get_stan_data = function() {
# Pure accessor: hands back the stored Stan input list as-is; nothing is
# recomputed here. Elsewhere in this class the field is assigned as
# c(stan_data, private$priors).
private$stan_input_data
},
#' @description View the mapping of human readable covariate names to the model variable p.
#' @return A data.table mapping the model variable p to human readable covariates.
#' @description View the mapping of human readable covariate names to the model variable p.
#' @return A data.table mapping the model variable p to human readable covariates.
get_covariate_lookup_table = function() {
# Pure accessor: returns the lookup table held on the private environment;
# nothing is recomputed here.
# NOTE(review): presumably populated by private$build_covariate_lookup_table()
# during initialize -- confirm, as that helper is not visible here.
private$covariate_lookup_table
},
#' @description Fit the model and return CmdStanMCMC fitted model object.
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html>
#' @param ... Named arguments to the `sample()` method of CmdStan model.
#' objects: <https://mc-stan.org/cmdstanr/reference/model-method-sample.html>
#' @description Fit the model and return CmdStanMCMC fitted model object.
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html>
#' @param ... Named arguments to the `sample()` method of CmdStan model
#' objects: <https://mc-stan.org/cmdstanr/reference/model-method-sample.html>
fit = function(...) {
  logger::log_info("Fitting model")
  # Run the sampler on the prepared Stan input, forwarding any extra
  # arguments untouched to the model's sample() method.
  sampled <- private$model$sample(private$stan_input_data, ...)
  # Keep the result on the object so later methods can use the fitted model,
  # then hand the same object back to the caller.
  private$fitted <- sampled
  sampled
},
#' @description Extract fitted population parameters
#' @return A data.table
#' @param n_draws Integer. Default 2000.
#' @param human_readable_covariates Logical. Default TRUE.
#' @description Extract fitted population parameters
#' @return A data.table
#' @param n_draws Integer. Default 2000.
#' @param human_readable_covariates Logical. Default TRUE.
extract_population_parameters = function(n_draws = 2000,
human_readable_covariates = TRUE) {
private$check_fitted()
Expand Down Expand Up @@ -489,8 +497,8 @@ biokinetics <- R6::R6Class(
attr(dt_out, "summarised") <- summarise
attr(dt_out, "scale") <- private$scale
attr(dt_out, "covariates") <- private$all_formula_vars
attr(dt_out, "upper_detection_limit") <- private$upper_detection_limit
attr(dt_out, "lower_detection_limit") <- private$lower_detection_limit
attr(dt_out, "upper_censoring_limit") <- private$upper_censoring_limit
attr(dt_out, "lower_censoring_limit") <- private$lower_censoring_limit

dt_out
},
Expand Down
40 changes: 20 additions & 20 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
#' @param tmax Integer. The number of time points in each simulated trajectory. Default 150.
#' @param n_draws Integer. The number of trajectories to simulate. Default 2000.
#' @param data Optional data.frame with columns time_since_last_exp and value. The raw data to compare to.
#' @param upper_detection_limit Optional upper detection limit.
#' @param lower_detection_limit Optional lower detection limit.
#' @param upper_censoring_limit Optional upper censoring limit.
#' @param lower_censoring_limit Optional lower censoring limit.
plot.biokinetics_priors <- function(x,
...,
tmax = 150,
n_draws = 2000,
data = NULL,
upper_detection_limit = NULL,
lower_detection_limit = NULL) {
upper_censoring_limit = NULL,
lower_censoring_limit = NULL) {

# Declare variables to suppress notes when compiling package
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153
Expand Down Expand Up @@ -57,7 +57,7 @@ plot.biokinetics_priors <- function(x,
aes(x = time_since_last_exp,
y = value))

plot <- add_limits(plot, upper_detection_limit, lower_detection_limit)
plot <- add_limits(plot, upper_censoring_limit, lower_censoring_limit)
}
plot
}
Expand All @@ -70,13 +70,13 @@ plot.biokinetics_priors <- function(x,
#' @param data A data.table with required columns time_since_last_exp, value and titre_type.
#' @param tmax Integer. The number of time points in each simulated trajectory. Default 150.
#' @param covariates Optional vector of covariate names to facet by (these must correspond to columns in the data.table)
#' @param upper_detection_limit Optional upper detection limit.
#' @param lower_detection_limit Optional lower detection limit.
#' @param upper_censoring_limit Optional upper censoring limit.
#' @param lower_censoring_limit Optional lower censoring limit.
plot_sero_data <- function(data,
tmax = 150,
covariates = character(0),
upper_detection_limit = NULL,
lower_detection_limit = NULL) {
upper_censoring_limit = NULL,
lower_censoring_limit = NULL) {
validate_required_cols(data, c("time_since_last_exp", "value", "titre_type"))
data <- data[time_since_last_exp <= tmax,]
# Declare variables to suppress notes when compiling package
Expand All @@ -90,7 +90,7 @@ plot_sero_data <- function(data,
facet_wrap(eval(parse(text = facet_formula(covariates)))) +
guides(colour = guide_legend(title = "Titre type"))

add_limits(plot, upper_detection_limit, lower_detection_limit)
add_limits(plot, upper_censoring_limit, lower_censoring_limit)
}

#' Plot method for "biokinetics_population_trajectories" class
Expand All @@ -104,8 +104,8 @@ plot_sero_data <- function(data,
plot.biokinetics_population_trajectories <- function(x, ...,
data = NULL) {
covariates <- attr(x, "covariates")
upper_detection_limit <- attr(x, "upper_detection_limit")
lower_detection_limit <- attr(x, "lower_detection_limit")
upper_censoring_limit <- attr(x, "upper_censoring_limit")
lower_censoring_limit <- attr(x, "lower_censoring_limit")

# Declare variables to suppress notes when compiling package
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153
Expand Down Expand Up @@ -139,7 +139,7 @@ plot.biokinetics_population_trajectories <- function(x, ...,
guides(fill = guide_legend(title = "Titre type"),
colour = "none")
if (!is.null(data)) {
plot <- add_limits(plot, upper_detection_limit, lower_detection_limit)
plot <- add_limits(plot, upper_censoring_limit, lower_censoring_limit)
}
plot
}
Expand All @@ -148,24 +148,24 @@ facet_formula <- function(covariates) {
paste("~", paste(c("titre_type", covariates), collapse = "+"))
}

add_limits <- function(plot, upper_detection_limit, lower_detection_limit) {
if (!is.null(lower_detection_limit)) {
add_limits <- function(plot, upper_censoring_limit, lower_censoring_limit) {
if (!is.null(lower_censoring_limit)) {
plot <- plot +
geom_hline(yintercept = lower_detection_limit,
geom_hline(yintercept = lower_censoring_limit,
linetype = 'dotted') +
annotate("text", x = 1,
y = lower_detection_limit,
y = lower_censoring_limit,
label = "Lower detection limit",
vjust = -0.5,
hjust = 0,
size = 3)
}
if (!is.null(upper_detection_limit)) {
if (!is.null(upper_censoring_limit)) {
plot <- plot +
geom_hline(yintercept = upper_detection_limit,
geom_hline(yintercept = upper_censoring_limit,
linetype = 'dotted') +
annotate("text", x = 1,
y = upper_detection_limit,
y = upper_censoring_limit,
label = "Upper detection limit",
vjust = -0.5,
hjust = 0,
Expand Down
Loading

0 comments on commit ccaeb9c

Please sign in to comment.