From c42e2c52cee9f70e6067f8e6474e62f8428bcae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20Li=C3=A9tar?= Date: Tue, 27 Feb 2024 20:02:49 +0000 Subject: [PATCH] Add initial support for saving and restoring simulation state. (#280) The model can now be run for a given number of steps, have its state saved, and then restore it and run for some more time steps. Parameters of the model may be modified when resuming, allowing a simulation to be forked with a single historical run and many futures, modelling different intervention scenarios. There are some limitations as to which parameters may be modified. In general, the structure of the simulation must remain the same, with the same variables and same events. This means interventions cannot be enabled or disabled when resuming, only their parameterization can change. Additionally, the timing of intervention events should not be modified. These limitations may be relaxed in the future, if there is a need for it. To avoid changing the existing API, this feature is available as a new `run_resumable_simulation` function. The function returns a pair of values, the existing dataframe with the simulation data and an object representing the internal simulation state. The function can be called a second time, passing the saved simulation state as an extra argument. See the `test-resume.R` file for usage of the new function. The implementation builds on top of individual's new support for this. Individual already saves the state of its native objects, ie. variables and events. The malaria model keeps quite a bit of state outside of individual, which we need to save and restore explicitly. This is done by creating a set of "stateful objects", ie. R6 objects with a pair `save_state` and `restore_state` methods. This interface is implemented by every bit of the model that has state to capture: - `LaggedValue` objects store the short term EIR and FOIM history. - `Solver` objects represent the current state of ODE solvers. - Adult mosquito ODEs keep a history of incubating values which need to be restored. - `CorrelationParameters` stores a randomly sampled value. This needs to be saved to ensure the simulation is resumed with that same value. - In addition to R's native RNG (whose state is already saved by individual), the model uses the dqrng library, whose state needs saving. --- DESCRIPTION | 6 +- R/RcppExports.R | 28 +++++++ R/biting_process.R | 8 +- R/compartmental.R | 81 +++++++++++++++++--- R/correlation.R | 19 +++++ R/lag.R | 8 ++ R/model.R | 64 ++++++++++++++-- R/mosquito_biology.R | 2 +- R/render.R | 2 +- R/stateful.R | 47 ++++++++++++ man/CorrelationParameters.Rd | 32 ++++++++ man/run_resumable_simulation.Rd | 34 +++++++++ src/Random.cpp | 15 +++- src/Random.h | 3 + src/RcppExports.cpp | 84 ++++++++++++++++++++- src/adult_mosquito_eqs.cpp | 24 +++++- src/adult_mosquito_eqs.h | 2 +- src/solver.cpp | 5 ++ src/timeseries.cpp | 51 +++++++++++-- src/timeseries.h | 5 +- src/utils.cpp | 10 +++ tests/testthat/helper-integration.R | 7 ++ tests/testthat/test-biting-integration.R | 2 +- tests/testthat/test-compartmental.R | 39 +++++++--- tests/testthat/test-emergence-integration.R | 33 ++++---- tests/testthat/test-resume.R | 41 ++++++++++ tests/testthat/test-seasonality.R | 6 +- 27 files changed, 583 insertions(+), 75 deletions(-) create mode 100644 R/stateful.R create mode 100644 man/run_resumable_simulation.Rd create mode 100644 tests/testthat/test-resume.R diff --git a/DESCRIPTION b/DESCRIPTION index 1d971723..ce25ce05 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -69,13 +69,15 @@ LazyData: true Remotes: mrc-ide/malariaEquilibrium, mrc-ide/individual +Additional_repositories: + https://mrc-ide.r-universe.dev Imports: - individual (>= 0.1.7), + individual (>= 0.1.13), malariaEquilibrium (>= 1.0.1), Rcpp, statmod, MASS, - dqrng, + dqrng (>= 0.3.2.2), sitmo, BH, R6, diff --git a/R/RcppExports.R b/R/RcppExports.R index 73a51dd0..01f3dc11 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -9,6 +9,14 @@ adult_mosquito_model_update <- function(model, mu, foim, susceptible, f) { invisible(.Call(`_malariasimulation_adult_mosquito_model_update`, model, mu, foim, susceptible, f)) } +adult_mosquito_model_save_state <- function(model) { + .Call(`_malariasimulation_adult_mosquito_model_save_state`, model) +} + +adult_mosquito_model_restore_state <- function(model, state) { + invisible(.Call(`_malariasimulation_adult_mosquito_model_restore_state`, model, state)) +} + create_adult_solver <- function(model, init, r_tol, a_tol, max_steps) { .Call(`_malariasimulation_create_adult_solver`, model, init, r_tol, a_tol, max_steps) } @@ -41,6 +49,10 @@ solver_get_states <- function(solver) { .Call(`_malariasimulation_solver_get_states`, solver) } +solver_set_states <- function(solver, state) { + invisible(.Call(`_malariasimulation_solver_set_states`, solver, state)) +} + solver_step <- function(solver) { invisible(.Call(`_malariasimulation_solver_step`, solver)) } @@ -57,10 +69,26 @@ timeseries_push <- function(timeseries, value, timestep) { invisible(.Call(`_malariasimulation_timeseries_push`, timeseries, value, timestep)) } +timeseries_save_state <- function(timeseries) { + .Call(`_malariasimulation_timeseries_save_state`, timeseries) +} + +timeseries_restore_state <- function(timeseries, state) { + invisible(.Call(`_malariasimulation_timeseries_restore_state`, timeseries, state)) +} + random_seed <- function(seed) { invisible(.Call(`_malariasimulation_random_seed`, seed)) } +random_save_state <- function() { + .Call(`_malariasimulation_random_save_state`) +} + +random_restore_state <- function(state) { + invisible(.Call(`_malariasimulation_random_restore_state`, state)) +} + bernoulli_multi_p_cpp <- function(p) { .Call(`_malariasimulation_bernoulli_multi_p_cpp`, p) } diff --git a/R/biting_process.R b/R/biting_process.R index ec09620d..801aea20 100644 --- a/R/biting_process.R +++ b/R/biting_process.R @@ -103,7 +103,7 @@ simulate_bites <- function( for (s_i in seq_along(parameters$species)) { species_name <- parameters$species[[s_i]] - solver_states <- solver_get_states(solvers[[s_i]]) + solver_states <- solvers[[s_i]]$get_states() p_bitten <- prob_bitten(timestep, variables, s_i, parameters) Q0 <- parameters$Q0[[s_i]] W <- average_p_successful(p_bitten$prob_bitten_survives, .pi, Q0) @@ -167,7 +167,7 @@ simulate_bites <- function( if (parameters$individual_mosquitoes) { # update the ODE with stats for ovoposition calculations aquatic_mosquito_model_update( - models[[s_i]], + models[[s_i]]$.model, species_index$size(), f, mu @@ -189,7 +189,7 @@ simulate_bites <- function( ) } else { adult_mosquito_model_update( - models[[s_i]], + models[[s_i]]$.model, mu, foim, solver_states[[ADULT_ODE_INDICES['Sm']]], @@ -235,7 +235,7 @@ calculate_infectious <- function(species, solvers, variables, parameters) { ) ) } - calculate_infectious_compartmental(solver_get_states(solvers[[species]])) + calculate_infectious_compartmental(solvers[[species]]$get_states()) } calculate_infectious_individual <- function( diff --git a/R/compartmental.R b/R/compartmental.R index f3405728..7df83a6e 100644 --- a/R/compartmental.R +++ b/R/compartmental.R @@ -50,16 +50,16 @@ parameterise_mosquito_models <- function(parameters, timesteps) { m )[ADULT_ODE_INDICES['Sm']] return( - create_adult_mosquito_model( + AdultMosquitoModel$new(create_adult_mosquito_model( growth_model, parameters$mum[[i]], parameters$dem, susceptible * parameters$init_foim, parameters$init_foim - ) + )) ) } - growth_model + AquaticMosquitoModel$new(growth_model) } ) } @@ -72,22 +72,22 @@ parameterise_solvers <- function(models, parameters) { init <- initial_mosquito_counts(parameters, i, parameters$init_foim, m) if (!parameters$individual_mosquitoes) { return( - create_adult_solver( - models[[i]], + Solver$new(create_adult_solver( + models[[i]]$.model, init, parameters$r_tol, parameters$a_tol, parameters$ode_max_steps - ) + )) ) } - create_aquatic_solver( - models[[i]], + Solver$new(create_aquatic_solver( + models[[i]]$.model, init[ODE_INDICES], parameters$r_tol, parameters$a_tol, parameters$ode_max_steps - ) + )) } ) } @@ -103,7 +103,7 @@ create_compartmental_rendering_process <- function(renderer, solvers, parameters counts <- rep(0, length(indices)) for (s_i in seq_along(solvers)) { if (parameters$species_proportions[[s_i]] > 0) { - row <- solver_get_states(solvers[[s_i]]) + row <- solvers[[s_i]]$get_states() } else { row <- rep(0, length(indices)) } @@ -128,8 +128,67 @@ create_solver_stepping_process <- function(solvers, parameters) { function(timestep) { for (i in seq_along(solvers)) { if (parameters$species_proportions[[i]] > 0) { - solver_step(solvers[[i]]) + solvers[[i]]$step() } } } } + +Solver <- R6::R6Class( + 'Solver', + private = list( + .solver = NULL + ), + public = list( + initialize = function(solver) { + private$.solver <- solver + }, + step = function() { + solver_step(private$.solver) + }, + get_states = function() { + solver_get_states(private$.solver) + }, + + # This is the same as `get_states`, just exposed under the interface that + # is expected of stateful objects. + save_state = function() { + solver_get_states(private$.solver) + }, + restore_state = function(state) { + solver_set_states(private$.solver, state) + } + ) +) + +AquaticMosquitoModel <- R6::R6Class( + 'AquaticMosquitoModel', + public = list( + .model = NULL, + initialize = function(model) { + self$.model <- model + }, + + # The aquatic mosquito model doesn't have any state to save or restore (the + # state of the ODE is stored separately). We still provide these methods to + # conform to the expected interface. + save_state = function() { NULL }, + restore_state = function(state) { } + ) +) + +AdultMosquitoModel <- R6::R6Class( + 'AdultMosquitoModel', + public = list( + .model = NULL, + initialize = function(model) { + self$.model <- model + }, + save_state = function() { + adult_mosquito_model_save_state(self$.model) + }, + restore_state = function(state) { + adult_mosquito_model_restore_state(self$.model, state) + } + ) +) diff --git a/R/correlation.R b/R/correlation.R index 68119369..df5f88f5 100644 --- a/R/correlation.R +++ b/R/correlation.R @@ -115,6 +115,25 @@ CorrelationParameters <- R6::R6Class( dimnames(private$.mvnorm)[[2]] <- private$interventions } private$.mvnorm + }, + + #' @description Save the correlation state. + save_state = function() { + # mvnorm is sampled at random lazily on its first use. We need to save it + # in order to restore the same value when resuming the simulation, + # otherwise we would be drawing a new, probably different, value. + # The rest of the object is derived deterministically from the parameters + # and does not need saving. + list(mvnorm=private$.mvnorm) + }, + + #' @description Restore the correlation state. + #' Only the randomly drawn weights are restored. The object needs to be + #' initialized with the same rhos. + #' @param state a previously saved correlation state, as returned by the + #' save_state method. + restore_state = function(state) { + private$.mvnorm <- state$mvnorm } ) ) diff --git a/R/lag.R b/R/lag.R index 35e47026..ca858ac7 100644 --- a/R/lag.R +++ b/R/lag.R @@ -14,6 +14,14 @@ LaggedValue <- R6::R6Class( get = function(timestep) { timeseries_at(private$history, timestep, TRUE) + }, + + save_state = function() { + timeseries_save_state(private$history) + }, + + restore_state = function(state) { + timeseries_restore_state(private$history, state) } ) ) diff --git a/R/model.R b/R/model.R index 4c4f7c2b..34bbfb46 100644 --- a/R/model.R +++ b/R/model.R @@ -87,6 +87,28 @@ run_simulation <- function( timesteps, parameters = NULL, correlations = NULL +) { + run_resumable_simulation(timesteps, parameters, correlations)$data +} + +#' @title Run the simulation in a resumable way +#' +#' @description this function accepts an initial simulation state as an argument, and returns the +#' final state after running all of its timesteps. This allows one run to be resumed, possibly +#' having changed some of the parameters. +#' @param timesteps the timestep at which to stop the simulation +#' @param parameters a named list of parameters to use +#' @param correlations correlation parameters +#' @param initial_state the state from which the simulation is resumed +#' @param restore_random_state if TRUE, restore the random number generator's state from the checkpoint. +#' @return a list with two entries, one for the dataframe of results and one for the final +#' simulation state. +run_resumable_simulation <- function( + timesteps, + parameters = NULL, + correlations = NULL, + initial_state = NULL, + restore_random_state = FALSE ) { random_seed(ceiling(runif(1) * .Machine$integer.max)) if (is.null(parameters)) { @@ -108,7 +130,23 @@ run_simulation <- function( ) vector_models <- parameterise_mosquito_models(parameters, timesteps) solvers <- parameterise_solvers(vector_models, parameters) - individual::simulation_loop( + + lagged_eir <- create_lagged_eir(variables, solvers, parameters) + lagged_infectivity <- create_lagged_infectivity(variables, parameters) + + stateful_objects <- unlist(list( + RandomState$new(restore_random_state), + correlations, + vector_models, + solvers, + lagged_eir, + lagged_infectivity)) + + if (!is.null(initial_state)) { + restore_state(initial_state$malariasimulation, stateful_objects) + } + + individual_state <- individual::simulation_loop( processes = create_processes( renderer, variables, @@ -117,15 +155,31 @@ run_simulation <- function( vector_models, solvers, correlations, - list(create_lagged_eir(variables, solvers, parameters)), - list(create_lagged_infectivity(variables, parameters)), + list(lagged_eir), + list(lagged_infectivity), timesteps ), variables = variables, events = unlist(events), - timesteps = timesteps + timesteps = timesteps, + state = initial_state$individual, + restore_random_state = restore_random_state ) - renderer$to_dataframe() + + final_state <- list( + timesteps=timesteps, + individual=individual_state, + malariasimulation=save_state(stateful_objects) + ) + + data <- renderer$to_dataframe() + if (!is.null(initial_state)) { + # Drop the timesteps we didn't simulate from the data. + # It would just be full of NA. + data <- data[-(1:initial_state$timesteps),] + } + + list(data=data, state=final_state) } #' @title Run a metapopulation model diff --git a/R/mosquito_biology.R b/R/mosquito_biology.R index 381450d9..4225a46f 100644 --- a/R/mosquito_biology.R +++ b/R/mosquito_biology.R @@ -220,7 +220,7 @@ create_mosquito_emergence_process <- function( p_counts <- vnapply( solvers, function(solver) { - solver_get_states(solver)[[ODE_INDICES[['P']]]] + solver$get_states()[[ODE_INDICES[['P']]]] } ) n <- sum(p_counts) * rate diff --git a/R/render.R b/R/render.R index 3fcf4882..60cb5b73 100644 --- a/R/render.R +++ b/R/render.R @@ -146,7 +146,7 @@ create_total_M_renderer_compartmental <- function(renderer, solvers, parameters) function(timestep) { total_M <- 0 for (i in seq_along(solvers)) { - row <- solver_get_states(solvers[[i]]) + row <- solvers[[i]]$get_states() species_M <- sum(row[ADULT_ODE_INDICES]) total_M <- total_M + species_M renderer$render(paste0('total_M_', parameters$species[[i]]), species_M, timestep) diff --git a/R/stateful.R b/R/stateful.R new file mode 100644 index 00000000..05c614a0 --- /dev/null +++ b/R/stateful.R @@ -0,0 +1,47 @@ +#' @title Save the state of a list of \emph{stateful objects} +#' @description The state of each element is saved and stored into a single +#' object, representing them in a way that can be exported and re-used later. +#' @param objects A list of stateful objects to be saved. Stateful objects are +#' instances of R6 classes with a pair of \code{save_state} and +#' \code{restore_state} methods. +#' @noRd +save_state <- function(objects) { + lapply(objects, function(o) o$save_state()) +} + +#' @title Restore the state of a collection of stateful objects +#' @description This is the counterpart of \code{save_state}. Calling it +#' restores the collection of objects back into their original state. +#' @param state A state object, as returned by the \code{save_state} function. +#' @param objects A collection of stateful objects to be restored. +#' @noRd +restore_state <- function(state, objects) { + stopifnot(length(state) == length(objects)) + for (i in seq_along(state)) { + objects[[i]]$restore_state(state[[i]]) + } +} + +#' @title a placeholder class to save the random number generator class. +#' @description the class integrates with the simulation loop to save and +#' restore the random number generator class when appropriate. +#' @noRd +RandomState <- R6::R6Class( + 'RandomState', + private = list( + restore_random_state = NULL + ), + public = list( + initialize = function(restore_random_state) { + private$restore_random_state <- restore_random_state + }, + save_state = function() { + random_save_state() + }, + restore_state = function(state) { + if (private$restore_random_state) { + random_restore_state(state) + } + } + ) +) diff --git a/man/CorrelationParameters.Rd b/man/CorrelationParameters.Rd index c2e6ada7..480d663d 100644 --- a/man/CorrelationParameters.Rd +++ b/man/CorrelationParameters.Rd @@ -19,6 +19,8 @@ Describes an event in the simulation \item \href{#method-CorrelationParameters-inter_intervention_rho}{\code{CorrelationParameters$inter_intervention_rho()}} \item \href{#method-CorrelationParameters-sigma}{\code{CorrelationParameters$sigma()}} \item \href{#method-CorrelationParameters-mvnorm}{\code{CorrelationParameters$mvnorm()}} +\item \href{#method-CorrelationParameters-save_state}{\code{CorrelationParameters$save_state()}} +\item \href{#method-CorrelationParameters-restore_state}{\code{CorrelationParameters$restore_state()}} \item \href{#method-CorrelationParameters-clone}{\code{CorrelationParameters$clone()}} } } @@ -101,6 +103,36 @@ multivariate norm draws for these parameters \if{html}{\out{
}}\preformatted{CorrelationParameters$mvnorm()}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CorrelationParameters-save_state}{}}} +\subsection{Method \code{save_state()}}{ +Save the correlation state. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CorrelationParameters$save_state()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CorrelationParameters-restore_state}{}}} +\subsection{Method \code{restore_state()}}{ +Restore the correlation state. +Only the randomly drawn weights are restored. The object needs to be +initialized with the same rhos. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CorrelationParameters$restore_state(state)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{state}}{a previously saved correlation state, as returned by the +save_state method.} +} +\if{html}{\out{
}} +} } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/run_resumable_simulation.Rd b/man/run_resumable_simulation.Rd new file mode 100644 index 00000000..8991fd1c --- /dev/null +++ b/man/run_resumable_simulation.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model.R +\name{run_resumable_simulation} +\alias{run_resumable_simulation} +\title{Run the simulation in a resumable way} +\usage{ +run_resumable_simulation( + timesteps, + parameters = NULL, + correlations = NULL, + initial_state = NULL, + restore_random_state = FALSE +) +} +\arguments{ +\item{timesteps}{the timestep at which to stop the simulation} + +\item{parameters}{a named list of parameters to use} + +\item{correlations}{correlation parameters} + +\item{initial_state}{the state from which the simulation is resumed} + +\item{restore_random_state}{if TRUE, restore the random number generator's state from the checkpoint.} +} +\value{ +a list with two entries, one for the dataframe of results and one for the final +simulation state. +} +\description{ +this function accepts an initial simulation state as an argument, and returns the +final state after running all of its timesteps. This allows one run to be resumed, possibly +having changed some of the parameters. +} diff --git a/src/Random.cpp b/src/Random.cpp index fad22963..2906ccc7 100644 --- a/src/Random.cpp +++ b/src/Random.cpp @@ -61,7 +61,7 @@ void Random::prop_sample_bucket( // all probabilities are the same if (heavy == n) { for (auto i = 0; i < size; ++i) { - *result = (*rng)(n); + *result = (*rng)((uint64_t)n); ++result; } return; @@ -122,10 +122,21 @@ void Random::prop_sample_bucket( // sample for (auto i = 0; i < size; ++i) { - size_t bucket = (*rng)(n); + size_t bucket = (*rng)((uint64_t)n); double acceptance = dqrng::uniform01((*rng)()); *result = (acceptance < dividing_probs[bucket]) ? bucket : alternative_index[bucket]; ++result; } } + +std::string Random::save_state() { + std::ostringstream stream; + stream << *rng; + return stream.str(); +} + +void Random::restore_state(std::string state) { + std::istringstream stream(state); + stream >> *rng; +} diff --git a/src/Random.h b/src/Random.h index b6e15039..e796fb9a 100644 --- a/src/Random.h +++ b/src/Random.h @@ -58,6 +58,9 @@ class Random : public RandomInterface { Random(Random &&other) = delete; Random& operator=(const Random &other) = delete; Random& operator=(Random &&other) = delete; + + std::string save_state(); + void restore_state(std::string state); private: Random() : rng(dqrng::generator(42)) {}; }; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f5c226fd..7675e8e6 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -40,6 +40,28 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// adult_mosquito_model_save_state +std::vector adult_mosquito_model_save_state(Rcpp::XPtr model); +RcppExport SEXP _malariasimulation_adult_mosquito_model_save_state(SEXP modelSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type model(modelSEXP); + rcpp_result_gen = Rcpp::wrap(adult_mosquito_model_save_state(model)); + return rcpp_result_gen; +END_RCPP +} +// adult_mosquito_model_restore_state +void adult_mosquito_model_restore_state(Rcpp::XPtr model, std::vector state); +RcppExport SEXP _malariasimulation_adult_mosquito_model_restore_state(SEXP modelSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type model(modelSEXP); + Rcpp::traits::input_parameter< std::vector >::type state(stateSEXP); + adult_mosquito_model_restore_state(model, state); + return R_NilValue; +END_RCPP +} // create_adult_solver Rcpp::XPtr create_adult_solver(Rcpp::XPtr model, std::vector init, double r_tol, double a_tol, size_t max_steps); RcppExport SEXP _malariasimulation_create_adult_solver(SEXP modelSEXP, SEXP initSEXP, SEXP r_tolSEXP, SEXP a_tolSEXP, SEXP max_stepsSEXP) { @@ -168,6 +190,17 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// solver_set_states +void solver_set_states(Rcpp::XPtr solver, std::vector state); +RcppExport SEXP _malariasimulation_solver_set_states(SEXP solverSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type solver(solverSEXP); + Rcpp::traits::input_parameter< std::vector >::type state(stateSEXP); + solver_set_states(solver, state); + return R_NilValue; +END_RCPP +} // solver_step void solver_step(Rcpp::XPtr solver); RcppExport SEXP _malariasimulation_solver_step(SEXP solverSEXP) { @@ -215,6 +248,28 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// timeseries_save_state +Rcpp::List timeseries_save_state(Rcpp::XPtr timeseries); +RcppExport SEXP _malariasimulation_timeseries_save_state(SEXP timeseriesSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type timeseries(timeseriesSEXP); + rcpp_result_gen = Rcpp::wrap(timeseries_save_state(timeseries)); + return rcpp_result_gen; +END_RCPP +} +// timeseries_restore_state +void timeseries_restore_state(Rcpp::XPtr timeseries, Rcpp::List state); +RcppExport SEXP _malariasimulation_timeseries_restore_state(SEXP timeseriesSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type timeseries(timeseriesSEXP); + Rcpp::traits::input_parameter< Rcpp::List >::type state(stateSEXP); + timeseries_restore_state(timeseries, state); + return R_NilValue; +END_RCPP +} // random_seed void random_seed(size_t seed); RcppExport SEXP _malariasimulation_random_seed(SEXP seedSEXP) { @@ -225,6 +280,26 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// random_save_state +std::string random_save_state(); +RcppExport SEXP _malariasimulation_random_save_state() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(random_save_state()); + return rcpp_result_gen; +END_RCPP +} +// random_restore_state +void random_restore_state(std::string state); +RcppExport SEXP _malariasimulation_random_restore_state(SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< std::string >::type state(stateSEXP); + random_restore_state(state); + return R_NilValue; +END_RCPP +} // bernoulli_multi_p_cpp std::vector bernoulli_multi_p_cpp(const std::vector p); RcppExport SEXP _malariasimulation_bernoulli_multi_p_cpp(SEXP pSEXP) { @@ -260,11 +335,13 @@ BEGIN_RCPP END_RCPP } -RcppExport SEXP run_testthat_tests(); +RcppExport SEXP run_testthat_tests(void); static const R_CallMethodDef CallEntries[] = { {"_malariasimulation_create_adult_mosquito_model", (DL_FUNC) &_malariasimulation_create_adult_mosquito_model, 5}, {"_malariasimulation_adult_mosquito_model_update", (DL_FUNC) &_malariasimulation_adult_mosquito_model_update, 5}, + {"_malariasimulation_adult_mosquito_model_save_state", (DL_FUNC) &_malariasimulation_adult_mosquito_model_save_state, 1}, + {"_malariasimulation_adult_mosquito_model_restore_state", (DL_FUNC) &_malariasimulation_adult_mosquito_model_restore_state, 2}, {"_malariasimulation_create_adult_solver", (DL_FUNC) &_malariasimulation_create_adult_solver, 5}, {"_malariasimulation_create_aquatic_mosquito_model", (DL_FUNC) &_malariasimulation_create_aquatic_mosquito_model, 18}, {"_malariasimulation_aquatic_mosquito_model_update", (DL_FUNC) &_malariasimulation_aquatic_mosquito_model_update, 4}, @@ -273,11 +350,16 @@ static const R_CallMethodDef CallEntries[] = { {"_malariasimulation_eggs_laid", (DL_FUNC) &_malariasimulation_eggs_laid, 3}, {"_malariasimulation_rainfall", (DL_FUNC) &_malariasimulation_rainfall, 5}, {"_malariasimulation_solver_get_states", (DL_FUNC) &_malariasimulation_solver_get_states, 1}, + {"_malariasimulation_solver_set_states", (DL_FUNC) &_malariasimulation_solver_set_states, 2}, {"_malariasimulation_solver_step", (DL_FUNC) &_malariasimulation_solver_step, 1}, {"_malariasimulation_create_timeseries", (DL_FUNC) &_malariasimulation_create_timeseries, 2}, {"_malariasimulation_timeseries_at", (DL_FUNC) &_malariasimulation_timeseries_at, 3}, {"_malariasimulation_timeseries_push", (DL_FUNC) &_malariasimulation_timeseries_push, 3}, + {"_malariasimulation_timeseries_save_state", (DL_FUNC) &_malariasimulation_timeseries_save_state, 1}, + {"_malariasimulation_timeseries_restore_state", (DL_FUNC) &_malariasimulation_timeseries_restore_state, 2}, {"_malariasimulation_random_seed", (DL_FUNC) &_malariasimulation_random_seed, 1}, + {"_malariasimulation_random_save_state", (DL_FUNC) &_malariasimulation_random_save_state, 0}, + {"_malariasimulation_random_restore_state", (DL_FUNC) &_malariasimulation_random_restore_state, 1}, {"_malariasimulation_bernoulli_multi_p_cpp", (DL_FUNC) &_malariasimulation_bernoulli_multi_p_cpp, 1}, {"_malariasimulation_bitset_index_cpp", (DL_FUNC) &_malariasimulation_bitset_index_cpp, 2}, {"_malariasimulation_fast_weighted_sample", (DL_FUNC) &_malariasimulation_fast_weighted_sample, 2}, diff --git a/src/adult_mosquito_eqs.cpp b/src/adult_mosquito_eqs.cpp index ad5553eb..8126c7d7 100644 --- a/src/adult_mosquito_eqs.cpp +++ b/src/adult_mosquito_eqs.cpp @@ -17,7 +17,7 @@ AdultMosquitoModel::AdultMosquitoModel( ) : growth_model(growth_model), mu(mu), tau(tau), foim(foim) { for (auto i = 0u; i < tau; ++i) { - lagged_incubating.push(incubating); + lagged_incubating.push_back(incubating); } } @@ -82,12 +82,30 @@ void adult_mosquito_model_update( model->foim = foim; model->growth_model.f = f; model->growth_model.mum = mu; - model->lagged_incubating.push(susceptible * foim); + model->lagged_incubating.push_back(susceptible * foim); if (model->lagged_incubating.size() > 0) { - model->lagged_incubating.pop(); + model->lagged_incubating.pop_front(); } } +//[[Rcpp::export]] +std::vector adult_mosquito_model_save_state( + Rcpp::XPtr model + ) { + // Only the lagged_incubating needs to be saved. The rest of the model + // state is reset at each time step by a call to update before it gets + // used. + return {model->lagged_incubating.begin(), model->lagged_incubating.end()}; +} + +//[[Rcpp::export]] +void adult_mosquito_model_restore_state( + Rcpp::XPtr model, + std::vector state + ) { + model->lagged_incubating.assign(state.begin(), state.end()); +} + //[[Rcpp::export]] Rcpp::XPtr create_adult_solver( Rcpp::XPtr model, diff --git a/src/adult_mosquito_eqs.h b/src/adult_mosquito_eqs.h index 6fc30501..c3e0ac8b 100644 --- a/src/adult_mosquito_eqs.h +++ b/src/adult_mosquito_eqs.h @@ -28,7 +28,7 @@ enum class AdultState : size_t {S = 3, E = 4, I = 5}; */ struct AdultMosquitoModel { AquaticMosquitoModel growth_model; - std::queue lagged_incubating; //last tau values for incubating mosquitos + std::deque lagged_incubating; //last tau values for incubating mosquitos double mu; //death rate for adult female mosquitoes const double tau; //extrinsic incubation period double foim; //force of infection towards mosquitoes diff --git a/src/solver.cpp b/src/solver.cpp index 7cb8b9f4..0a5f7401 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -47,6 +47,11 @@ std::vector solver_get_states(Rcpp::XPtr solver) { return solver->state; } +//[[Rcpp::export]] +void solver_set_states(Rcpp::XPtr solver, std::vector state) { + solver->state = state; +} + //[[Rcpp::export]] void solver_step(Rcpp::XPtr solver) { solver->step(); diff --git a/src/timeseries.cpp b/src/timeseries.cpp index 2383ac73..51a494de 100644 --- a/src/timeseries.cpp +++ b/src/timeseries.cpp @@ -15,18 +15,18 @@ Timeseries::Timeseries(size_t max_size, double default_value) : max_size(max_size), has_default(true), default_value(default_value) {} void Timeseries::push(double value, double time) { - values.insert({time, value}); + _values.insert({time, value}); if (max_size != -1) { - while(values.size() > max_size) { - values.erase(values.begin()); + while(_values.size() > max_size) { + _values.erase(_values.begin()); } } } double Timeseries::at(double time, bool linear) const { - auto it = values.lower_bound(time); - if (it == values.end()) { - if (values.size() > 0 && !linear) { + auto it = _values.lower_bound(time); + if (it == _values.end()) { + if (_values.size() > 0 && !linear) { it--; return it->second; } @@ -45,7 +45,7 @@ double Timeseries::at(double time, bool linear) const { auto after_element = *it; while(it->first > time) { // Check if we're at the start of the timeseries - if (it == values.begin()) { + if (it == _values.begin()) { if (has_default) { return default_value; } @@ -64,6 +64,14 @@ double Timeseries::at(double time, bool linear) const { return it->second; } +const std::map& Timeseries::values() { + return _values; +} + +void Timeseries::set_values(std::map values) { + _values = std::move(values); +} + //[[Rcpp::export]] Rcpp::XPtr create_timeseries(size_t size, double default_value) { return Rcpp::XPtr(new Timeseries(size, default_value), true); @@ -78,3 +86,32 @@ double timeseries_at(Rcpp::XPtr timeseries, double timestep, bool li void timeseries_push(Rcpp::XPtr timeseries, double value, double timestep) { return timeseries->push(value, timestep); } + +//[[Rcpp::export]] +Rcpp::List timeseries_save_state(Rcpp::XPtr timeseries) { + std::vector timesteps; + std::vector values; + for (const auto& entry: timeseries->values()) { + timesteps.push_back(entry.first); + values.push_back(entry.second); + } + return Rcpp::DataFrame::create( + Rcpp::Named("timestep") = timesteps, + Rcpp::Named("value") = values + ); +} + +//[[Rcpp::export]] +void timeseries_restore_state(Rcpp::XPtr timeseries, Rcpp::List state) { + std::vector timesteps = state["timestep"]; + std::vector values = state["value"]; + if (timesteps.size() != values.size()) { + Rcpp::stop("Bad size"); + } + + std::map values_map; + for (size_t i = 0; i < timesteps.size(); i++) { + values_map.insert({timesteps[i], values[i]}); + } + timeseries->set_values(std::move(values_map)); +} diff --git a/src/timeseries.h b/src/timeseries.h index c9e7c32f..78f3780b 100644 --- a/src/timeseries.h +++ b/src/timeseries.h @@ -12,7 +12,7 @@ class Timeseries { private: - std::map values; + std::map _values; size_t max_size; void clean(); bool has_default; @@ -23,6 +23,9 @@ class Timeseries { Timeseries(size_t, double); void push(double, double); double at(double, bool = true) const; + + const std::map& values(); + void set_values(std::map state); }; #endif /* SRC_TIMESERIES_ */ diff --git a/src/utils.cpp b/src/utils.cpp index 838a9530..d2b36043 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -8,6 +8,16 @@ void random_seed(size_t seed) { Random::get_instance().seed(seed); } +//[[Rcpp::export]] +std::string random_save_state() { + return Random::get_instance().save_state(); +} + +//[[Rcpp::export]] +void random_restore_state(std::string state) { + return Random::get_instance().restore_state(state); +} + //[[Rcpp::export]] std::vector bernoulli_multi_p_cpp(const std::vector p) { auto values = Random::get_instance().bernoulli_multi_p(p); diff --git a/tests/testthat/helper-integration.R b/tests/testthat/helper-integration.R index b69b7dcb..22c7dd63 100644 --- a/tests/testthat/helper-integration.R +++ b/tests/testthat/helper-integration.R @@ -88,6 +88,13 @@ mock_event <- function(event) { ) } +mock_solver <- function(states) { + list( + get_states = mockery::mock(states), + step = mockery::mock() + ) +} + expect_bitset_update <- function(mock, value, index, call = 1) { expect_equal(mockery::mock_args(mock)[[call]][[1]], value) expect_equal(mockery::mock_args(mock)[[call]][[2]]$to_vector(), index) diff --git a/tests/testthat/test-biting-integration.R b/tests/testthat/test-biting-integration.R index 66bf5b72..7da79b66 100644 --- a/tests/testthat/test-biting-integration.R +++ b/tests/testthat/test-biting-integration.R @@ -131,7 +131,7 @@ test_that('simulate_bites integrates eir calculation and mosquito side effects', expect_equal(effects_args[[1]][[8]], parameters) expect_equal(effects_args[[1]][[9]], timestep) - mockery::expect_args(eqs_update, 1, models[[1]], 25, f, parameters$mum) + mockery::expect_args(eqs_update, 1, models[[1]]$.model, 25, f, parameters$mum) mockery::expect_args( pois_mock, 1, diff --git a/tests/testthat/test-compartmental.R b/tests/testthat/test-compartmental.R index a4fb3b6a..9958c7a5 100644 --- a/tests/testthat/test-compartmental.R +++ b/tests/testthat/test-compartmental.R @@ -11,9 +11,14 @@ test_that('ODE stays at equilibrium with a constant total_M', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } expected <- c() @@ -41,16 +46,16 @@ test_that('Adult ODE stays at equilibrium with a constant foim and mu', { counts <- c() for (t in seq(timesteps)) { - states <- solver_get_states(solvers[[1]]) + states <- solvers[[1]]$get_states() counts <- rbind(counts, c(t, states)) adult_mosquito_model_update( - models[[1]], + models[[1]]$.model, parameters$mum, parameters$init_foim, states[ADULT_ODE_INDICES['Sm']], f ) - solver_step(solvers[[1]]) + solvers[[1]]$step() } expected <- c() @@ -82,9 +87,14 @@ test_that('ODE stays at equilibrium with low total_M', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } expected <- c() @@ -121,14 +131,19 @@ test_that('Changing total_M stabilises', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) if (t < change) { total_M <- total_M_0 } else { total_M <- total_M_1 } - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } initial_eq <- initial_mosquito_counts( diff --git a/tests/testthat/test-emergence-integration.R b/tests/testthat/test-emergence-integration.R index fa7464be..bf452d2f 100644 --- a/tests/testthat/test-emergence-integration.R +++ b/tests/testthat/test-emergence-integration.R @@ -8,22 +8,19 @@ test_that('emergence process fails when there are not enough individuals', { c('gamb'), rep('gamb', 2000) ) + solvers <- list( + mock_solver(c(1000, 500, 100)), + mock_solver(c(1000, 500, 100)) + ) + emergence_process <- create_mosquito_emergence_process( - list(), + solvers, state, species, c('gamb'), parameters$dpl ) - mockery::stub( - emergence_process, - 'solver_get_states', - mockery::mock( - c(1000, 500, 100), - c(1000, 500, 100) - ) - ) - expect_error(emergence_process(0), '*') + expect_error(emergence_process(0), 'Not enough mosquitoes') }) test_that('emergence_process creates the correct number of susceptables', { @@ -36,23 +33,19 @@ test_that('emergence_process creates the correct number of susceptables', { c('a', 'b'), c('a', 'b') ) + solvers <- list( + mock_solver(c(100000, 50000, 10000)), + mock_solver(c(1000, 5000, 1000)) + ) + emergence_process <- create_mosquito_emergence_process( - list(mockery::mock(), mockery::mock()), + solvers, state, species, c('a', 'b'), parameters$dpl ) - mockery::stub( - emergence_process, - 'solver_get_states', - mockery::mock( - c(100000, 50000, 10000), - c(10000, 5000, 1000) - ) - ) - emergence_process(0) expect_bitset_update( diff --git a/tests/testthat/test-resume.R b/tests/testthat/test-resume.R new file mode 100644 index 00000000..83bada35 --- /dev/null +++ b/tests/testthat/test-resume.R @@ -0,0 +1,41 @@ +test_that('Simulation can be resumed', { + initial_timesteps <- 50 + total_timesteps <- 100 + + parameters <- get_parameters() + + set.seed(1) + first_phase <- run_resumable_simulation(initial_timesteps, parameters) + second_phase <- run_resumable_simulation( + total_timesteps, + parameters, + initial_state=first_phase$state, + restore_random_state=TRUE) + + set.seed(1) + uninterrupted_run <- run_simulation(total_timesteps, parameters=parameters) + + expect_equal(nrow(first_phase$data), initial_timesteps) + expect_equal(nrow(second_phase$data), total_timesteps - initial_timesteps) + expect_equal(rbind(first_phase$data, second_phase$data), uninterrupted_run) +}) + +test_that('Intervention parameters can be changed when resuming', { + initial_timesteps <- 50 + total_timesteps <- 100 + tbv_timesteps <- 70 + + # Because of how event scheduling works, we must enable TBV even in the inital phase. + # We set a coverage of 0 to act as-if it was disabled. + initial_parameters <- get_parameters() |> set_tbv(timesteps=tbv_timesteps, coverage=0, ages=5:60) + + tbv_parameters <- initial_parameters |> + set_tbv(timesteps=tbv_timesteps, coverage=1, ages=5:60) + + initial_run <- run_resumable_simulation(initial_timesteps, initial_parameters) + control_run <- run_resumable_simulation(total_timesteps, initial_parameters, initial_state = initial_run$state) + tbv_run <- run_resumable_simulation(total_timesteps, tbv_parameters, initial_state = initial_run$state) + + expect_equal(control_run$data$n_vaccinated_tbv[tbv_timesteps - initial_timesteps], 0) + expect_gt(tbv_run$data$n_vaccinated_tbv[tbv_timesteps - initial_timesteps], 0) +}) diff --git a/tests/testthat/test-seasonality.R b/tests/testthat/test-seasonality.R index 01302d8d..f600ebf4 100644 --- a/tests/testthat/test-seasonality.R +++ b/tests/testthat/test-seasonality.R @@ -15,14 +15,14 @@ test_that('Seasonality correctly affects P', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) aquatic_mosquito_model_update( - models[[1]], + models[[1]]$.model, total_M, parameters$blood_meal_rates, parameters$mum ) - solver_step(solvers[[1]]) + solvers[[1]]$step() } burn_in <- 20