Skip to content

Commit

Permalink
Allow simulation to be restored with new interventions. (#286)
Browse files Browse the repository at this point in the history
This uses a few of individual's new features that make restoring more
flexible. It also fixes a bug when restoring the mosquitto solvers, by
correctly restoring the timestep, which is needed to model seasonality.

The intervention events use the new `restore = FALSE` flag to make sure
their schedule can be modified when resuming. Instead of having the
events re-schedule themselves everytime they fire, we setup the entire
schedule upfront when initialising the simulation.

It adds end-to-end testing of this feature, across a range of scenarios.
For each scenario, the outcomes of the simulation with and without
restoring are compared and we make sure they are equivalent.
  • Loading branch information
plietar committed May 1, 2024
1 parent 9debd7f commit e534db7
Show file tree
Hide file tree
Showing 18 changed files with 217 additions and 146 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Remotes:
Additional_repositories:
https://mrc-ide.r-universe.dev
Imports:
individual (>= 0.1.15),
individual (>= 0.1.16),
malariaEquilibrium (>= 1.0.1),
Rcpp,
statmod,
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ solver_get_states <- function(solver) {
.Call(`_malariasimulation_solver_get_states`, solver)
}

solver_set_states <- function(solver, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, state))
solver_set_states <- function(solver, t, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, t, state))
}

solver_step <- function(solver) {
Expand Down
8 changes: 4 additions & 4 deletions R/compartmental.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ Solver <- R6::R6Class(
save_state = function() {
solver_get_states(private$.solver)
},
restore_state = function(state) {
solver_set_states(private$.solver, state)
restore_state = function(t, state) {
solver_set_states(private$.solver, t, state)
}
)
)
Expand All @@ -173,7 +173,7 @@ AquaticMosquitoModel <- R6::R6Class(
# state of the ODE is stored separately). We still provide these methods to
# conform to the expected interface.
save_state = function() { NULL },
restore_state = function(state) { }
restore_state = function(t, state) { }
)
)

Expand All @@ -187,7 +187,7 @@ AdultMosquitoModel <- R6::R6Class(
save_state = function() {
adult_mosquito_model_save_state(self$.model)
},
restore_state = function(state) {
restore_state = function(t, state) {
adult_mosquito_model_restore_state(self$.model, state)
}
)
Expand Down
9 changes: 7 additions & 2 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,16 @@ CorrelationParameters <- R6::R6Class(
},

#' @description Restore the correlation state.
#'
#' Only the randomly drawn weights are restored. The object needs to be
#' initialized with the same rhos.
#'
#' @param timestep the timestep at which simulation is resumed. This
#' parameter's value is ignored, it only exists to conform to a uniform
#' interface.
#' @param state a previously saved correlation state, as returned by the
#' save_state method.
restore_state = function(state) {
#' save_state method.
restore_state = function(timestep, state) {
private$.mvnorm <- state$mvnorm
}
)
Expand Down
18 changes: 8 additions & 10 deletions R/events.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
create_events <- function(parameters) {
events <- list(
# MDA events
mda_administer = individual::Event$new(),
smc_administer = individual::Event$new(),
mda_administer = individual::Event$new(restore=FALSE),
smc_administer = individual::Event$new(restore=FALSE),

# TBV event
tbv_vaccination = individual::Event$new(),
tbv_vaccination = individual::Event$new(restore=FALSE),

# Bednet events
throw_away_net = individual::TargetedEvent$new(parameters$human_population)
Expand All @@ -21,7 +21,7 @@ create_events <- function(parameters) {
seq_along(parameters$mass_pev_booster_spacing),
function(.) individual::TargetedEvent$new(parameters$human_population)
)
events$mass_pev <- individual::Event$new()
events$mass_pev <- individual::Event$new(restore=FALSE)
events$mass_pev_doses <- mass_pev_doses
events$mass_pev_boosters <- mass_pev_boosters
}
Expand Down Expand Up @@ -63,16 +63,16 @@ initialise_events <- function(events, variables, parameters) {

# Initialise scheduled interventions
if (!is.null(parameters$mass_pev_timesteps)) {
events$mass_pev$schedule(parameters$mass_pev_timesteps[[1]] - 1)
events$mass_pev$schedule(parameters$mass_pev_timesteps - 1)
}
if (parameters$mda) {
events$mda_administer$schedule(parameters$mda_timesteps[[1]] - 1)
events$mda_administer$schedule(parameters$mda_timesteps - 1)
}
if (parameters$smc) {
events$smc_administer$schedule(parameters$smc_timesteps[[1]] - 1)
events$smc_administer$schedule(parameters$smc_timesteps - 1)
}
if (parameters$tbv) {
events$tbv_vaccination$schedule(parameters$tbv_timesteps[[1]] - 1)
events$tbv_vaccination$schedule(parameters$tbv_timesteps - 1)
}
}

Expand Down Expand Up @@ -158,7 +158,6 @@ attach_event_listeners <- function(
if (parameters$mda == 1) {
events$mda_administer$add_listener(create_mda_listeners(
variables,
events$mda_administer,
parameters$mda_drug,
parameters$mda_timesteps,
parameters$mda_coverages,
Expand All @@ -174,7 +173,6 @@ attach_event_listeners <- function(
if (parameters$smc == 1) {
events$smc_administer$add_listener(create_mda_listeners(
variables,
events$smc_administer,
parameters$smc_drug,
parameters$smc_timesteps,
parameters$smc_coverages,
Expand Down
2 changes: 1 addition & 1 deletion R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ LaggedValue <- R6::R6Class(
timeseries_save_state(private$history)
},

restore_state = function(state) {
restore_state = function(t, state) {
timeseries_restore_state(private$history, state)
}
)
Expand Down
7 changes: 0 additions & 7 deletions R/mda_processes.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#' @title Create listeners for MDA events
#' @param variables the variables available in the model
#' @param administer_event the event schedule for drug administration
#' @param drug the drug to administer
#' @param timesteps timesteps for each round
#' @param coverages the coverage for each round
Expand All @@ -14,7 +13,6 @@
#' @noRd
create_mda_listeners <- function(
variables,
administer_event,
drug,
timesteps,
coverages,
Expand Down Expand Up @@ -78,11 +76,6 @@ create_mda_listeners <- function(
variables$drug$queue_update(drug, to_move)
variables$drug_time$queue_update(timestep, to_move)
}

# Schedule next round
if (time_index < length(timesteps)) {
administer_event$schedule(timesteps[[time_index + 1]] - timestep)
}
}
}

Expand Down
17 changes: 10 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,19 @@ run_resumable_simulation <- function(
lagged_eir <- create_lagged_eir(variables, solvers, parameters)
lagged_infectivity <- create_lagged_infectivity(variables, parameters)

stateful_objects <- unlist(list(
stateful_objects <- list(
RandomState$new(restore_random_state),
correlations,
vector_models,
solvers,
lagged_eir,
lagged_infectivity))
lagged_infectivity)

if (!is.null(initial_state)) {
restore_state(initial_state$malariasimulation, stateful_objects)
individual::restore_object_state(
initial_state$timesteps,
stateful_objects,
initial_state$malariasimulation)
}

individual_state <- individual::simulation_loop(
Expand All @@ -161,16 +164,16 @@ run_resumable_simulation <- function(
timesteps
),
variables = variables,
events = unlist(events),
events = events,
timesteps = timesteps,
state = initial_state$individual,
restore_random_state = restore_random_state
)

final_state <- list(
timesteps=timesteps,
individual=individual_state,
malariasimulation=save_state(stateful_objects)
timesteps = timesteps,
individual = individual_state,
malariasimulation = individual::save_object_state(stateful_objects)
)

data <- renderer$to_dataframe()
Expand Down
5 changes: 0 additions & 5 deletions R/pev.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,6 @@ create_mass_pev_listener <- function(
parameters,
events$mass_pev_doses
)
if (time_index < length(parameters$mass_pev_timesteps)) {
events$mass_pev$schedule(
parameters$mass_pev_timesteps[[time_index + 1]] - timestep
)
}
}
}

Expand Down
47 changes: 0 additions & 47 deletions R/stateful.R

This file was deleted.

5 changes: 0 additions & 5 deletions R/tbv.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ create_tbv_listener <- function(variables, events, parameters, correlations, ren
to_vaccinate
)
}
if (time_index < length(parameters$tbv_timesteps)) {
events$tbv_vaccination$schedule(
parameters$tbv_timesteps[[time_index + 1]] - timestep
)
}
}
}

Expand Down
24 changes: 24 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,27 @@ rtexp <- function(n, m, t) { itexp(runif(n), m, t) }
match_timestep <- function(ts, t) {
min(sum(ts <= t), length(ts))
}

#' @title a placeholder class to save the random number generator class.
#' @description the class integrates with the simulation loop to save and
#' restore the random number generator class when appropriate.
#' @noRd
RandomState <- R6::R6Class(
'RandomState',
private = list(
restore_random_state = NULL
),
public = list(
initialize = function(restore_random_state) {
private$restore_random_state <- restore_random_state
},
save_state = function() {
random_save_state()
},
restore_state = function(t, state) {
if (private$restore_random_state) {
random_restore_state(state)
}
}
)
)
7 changes: 6 additions & 1 deletion man/CorrelationParameters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,14 @@ BEGIN_RCPP
END_RCPP
}
// solver_set_states
void solver_set_states(Rcpp::XPtr<Solver> solver, std::vector<double> state);
RcppExport SEXP _malariasimulation_solver_set_states(SEXP solverSEXP, SEXP stateSEXP) {
void solver_set_states(Rcpp::XPtr<Solver> solver, double t, std::vector<double> state);
RcppExport SEXP _malariasimulation_solver_set_states(SEXP solverSEXP, SEXP tSEXP, SEXP stateSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::XPtr<Solver> >::type solver(solverSEXP);
Rcpp::traits::input_parameter< double >::type t(tSEXP);
Rcpp::traits::input_parameter< std::vector<double> >::type state(stateSEXP);
solver_set_states(solver, state);
solver_set_states(solver, t, state);
return R_NilValue;
END_RCPP
}
Expand Down Expand Up @@ -363,7 +364,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_malariasimulation_rainfall", (DL_FUNC) &_malariasimulation_rainfall, 5},
{"_malariasimulation_exponential_process_cpp", (DL_FUNC) &_malariasimulation_exponential_process_cpp, 2},
{"_malariasimulation_solver_get_states", (DL_FUNC) &_malariasimulation_solver_get_states, 1},
{"_malariasimulation_solver_set_states", (DL_FUNC) &_malariasimulation_solver_set_states, 2},
{"_malariasimulation_solver_set_states", (DL_FUNC) &_malariasimulation_solver_set_states, 3},
{"_malariasimulation_solver_step", (DL_FUNC) &_malariasimulation_solver_step, 1},
{"_malariasimulation_create_timeseries", (DL_FUNC) &_malariasimulation_create_timeseries, 2},
{"_malariasimulation_timeseries_at", (DL_FUNC) &_malariasimulation_timeseries_at, 3},
Expand Down
3 changes: 2 additions & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ std::vector<double> solver_get_states(Rcpp::XPtr<Solver> solver) {
}

//[[Rcpp::export]]
void solver_set_states(Rcpp::XPtr<Solver> solver, std::vector<double> state) {
void solver_set_states(Rcpp::XPtr<Solver> solver, double t, std::vector<double> state) {
solver->t = t;
solver->state = state;
}

Expand Down
Loading

0 comments on commit e534db7

Please sign in to comment.