Skip to content

Commit

Permalink
Add initial support for saving and restoring simulation state. (#280)
Browse files Browse the repository at this point in the history
The model can now be run for a given number of steps, have its state
saved, and then restore it and run for some more time steps.

Parameters of the model may be modified when resuming, allowing a
simulation to be forked with a single historical run and many futures,
modelling different intervention scenarios.

There are some limitations as to which parameters may be modified. In
general, the structure of the simulation must remain the same, with the
same variables and same events. This means interventions cannot be
enabled or disabled when resuming, only their parameterization can
change. Additionally, the timing of intervention events should not be
modified. These limitations may be relaxed in the future, if there is a
need for it.

To avoid changing the existing API, this feature is available as a new
`run_resumable_simulation` function. The function returns a pair of
values, the existing dataframe with the simulation data and an object
representing the internal simulation state. The function can be called a
second time, passing the saved simulation state as an extra argument.
See the `test-resume.R` file for usage of the new function.

The implementation builds on top of individual's new support for this.
Individual already saves the state of its native objects, ie. variables
and events. The malaria model keeps quite a bit of state outside of
individual, which we need to save and restore explicitly. This is done
by creating a set of "stateful objects", ie. R6 objects with a pair
`save_state` and `restore_state` methods. This interface is implemented
by every bit of the model that has state to capture:

- `LaggedValue` objects store the short term EIR and FOIM history.
- `Solver` objects represent the current state of ODE solvers.
- Adult mosquito ODEs keep a history of incubating values which need to
  be restored.
- `CorrelationParameters` stores a randomly sampled value. This needs to
  be saved to ensure the simulation is resumed with that same value.
- In addition to R's native RNG (whose state is already saved by
  individual), the model uses the dqrng library, whose state needs
  saving.
  • Loading branch information
plietar authored Feb 27, 2024
1 parent 1a5f00a commit c42e2c5
Show file tree
Hide file tree
Showing 27 changed files with 583 additions and 75 deletions.
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ LazyData: true
Remotes:
mrc-ide/malariaEquilibrium,
mrc-ide/individual
Additional_repositories:
https://mrc-ide.r-universe.dev
Imports:
individual (>= 0.1.7),
individual (>= 0.1.13),
malariaEquilibrium (>= 1.0.1),
Rcpp,
statmod,
MASS,
dqrng,
dqrng (>= 0.3.2.2),
sitmo,
BH,
R6,
Expand Down
28 changes: 28 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ adult_mosquito_model_update <- function(model, mu, foim, susceptible, f) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_update`, model, mu, foim, susceptible, f))
}

adult_mosquito_model_save_state <- function(model) {
.Call(`_malariasimulation_adult_mosquito_model_save_state`, model)
}

adult_mosquito_model_restore_state <- function(model, state) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_restore_state`, model, state))
}

create_adult_solver <- function(model, init, r_tol, a_tol, max_steps) {
.Call(`_malariasimulation_create_adult_solver`, model, init, r_tol, a_tol, max_steps)
}
Expand Down Expand Up @@ -41,6 +49,10 @@ solver_get_states <- function(solver) {
.Call(`_malariasimulation_solver_get_states`, solver)
}

solver_set_states <- function(solver, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, state))
}

solver_step <- function(solver) {
invisible(.Call(`_malariasimulation_solver_step`, solver))
}
Expand All @@ -57,10 +69,26 @@ timeseries_push <- function(timeseries, value, timestep) {
invisible(.Call(`_malariasimulation_timeseries_push`, timeseries, value, timestep))
}

timeseries_save_state <- function(timeseries) {
.Call(`_malariasimulation_timeseries_save_state`, timeseries)
}

timeseries_restore_state <- function(timeseries, state) {
invisible(.Call(`_malariasimulation_timeseries_restore_state`, timeseries, state))
}

random_seed <- function(seed) {
invisible(.Call(`_malariasimulation_random_seed`, seed))
}

random_save_state <- function() {
.Call(`_malariasimulation_random_save_state`)
}

random_restore_state <- function(state) {
invisible(.Call(`_malariasimulation_random_restore_state`, state))
}

bernoulli_multi_p_cpp <- function(p) {
.Call(`_malariasimulation_bernoulli_multi_p_cpp`, p)
}
Expand Down
8 changes: 4 additions & 4 deletions R/biting_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ simulate_bites <- function(

for (s_i in seq_along(parameters$species)) {
species_name <- parameters$species[[s_i]]
solver_states <- solver_get_states(solvers[[s_i]])
solver_states <- solvers[[s_i]]$get_states()
p_bitten <- prob_bitten(timestep, variables, s_i, parameters)
Q0 <- parameters$Q0[[s_i]]
W <- average_p_successful(p_bitten$prob_bitten_survives, .pi, Q0)
Expand Down Expand Up @@ -167,7 +167,7 @@ simulate_bites <- function(
if (parameters$individual_mosquitoes) {
# update the ODE with stats for ovoposition calculations
aquatic_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
species_index$size(),
f,
mu
Expand All @@ -189,7 +189,7 @@ simulate_bites <- function(
)
} else {
adult_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
mu,
foim,
solver_states[[ADULT_ODE_INDICES['Sm']]],
Expand Down Expand Up @@ -235,7 +235,7 @@ calculate_infectious <- function(species, solvers, variables, parameters) {
)
)
}
calculate_infectious_compartmental(solver_get_states(solvers[[species]]))
calculate_infectious_compartmental(solvers[[species]]$get_states())
}

calculate_infectious_individual <- function(
Expand Down
81 changes: 70 additions & 11 deletions R/compartmental.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ parameterise_mosquito_models <- function(parameters, timesteps) {
m
)[ADULT_ODE_INDICES['Sm']]
return(
create_adult_mosquito_model(
AdultMosquitoModel$new(create_adult_mosquito_model(
growth_model,
parameters$mum[[i]],
parameters$dem,
susceptible * parameters$init_foim,
parameters$init_foim
)
))
)
}
growth_model
AquaticMosquitoModel$new(growth_model)
}
)
}
Expand All @@ -72,22 +72,22 @@ parameterise_solvers <- function(models, parameters) {
init <- initial_mosquito_counts(parameters, i, parameters$init_foim, m)
if (!parameters$individual_mosquitoes) {
return(
create_adult_solver(
models[[i]],
Solver$new(create_adult_solver(
models[[i]]$.model,
init,
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
)
}
create_aquatic_solver(
models[[i]],
Solver$new(create_aquatic_solver(
models[[i]]$.model,
init[ODE_INDICES],
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
}
)
}
Expand All @@ -103,7 +103,7 @@ create_compartmental_rendering_process <- function(renderer, solvers, parameters
counts <- rep(0, length(indices))
for (s_i in seq_along(solvers)) {
if (parameters$species_proportions[[s_i]] > 0) {
row <- solver_get_states(solvers[[s_i]])
row <- solvers[[s_i]]$get_states()
} else {
row <- rep(0, length(indices))
}
Expand All @@ -128,8 +128,67 @@ create_solver_stepping_process <- function(solvers, parameters) {
function(timestep) {
for (i in seq_along(solvers)) {
if (parameters$species_proportions[[i]] > 0) {
solver_step(solvers[[i]])
solvers[[i]]$step()
}
}
}
}

Solver <- R6::R6Class(
'Solver',
private = list(
.solver = NULL
),
public = list(
initialize = function(solver) {
private$.solver <- solver
},
step = function() {
solver_step(private$.solver)
},
get_states = function() {
solver_get_states(private$.solver)
},

# This is the same as `get_states`, just exposed under the interface that
# is expected of stateful objects.
save_state = function() {
solver_get_states(private$.solver)
},
restore_state = function(state) {
solver_set_states(private$.solver, state)
}
)
)

AquaticMosquitoModel <- R6::R6Class(
'AquaticMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},

# The aquatic mosquito model doesn't have any state to save or restore (the
# state of the ODE is stored separately). We still provide these methods to
# conform to the expected interface.
save_state = function() { NULL },
restore_state = function(state) { }
)
)

AdultMosquitoModel <- R6::R6Class(
'AdultMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},
save_state = function() {
adult_mosquito_model_save_state(self$.model)
},
restore_state = function(state) {
adult_mosquito_model_restore_state(self$.model, state)
}
)
)
19 changes: 19 additions & 0 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,25 @@ CorrelationParameters <- R6::R6Class(
dimnames(private$.mvnorm)[[2]] <- private$interventions
}
private$.mvnorm
},

#' @description Save the correlation state.
save_state = function() {
# mvnorm is sampled at random lazily on its first use. We need to save it
# in order to restore the same value when resuming the simulation,
# otherwise we would be drawing a new, probably different, value.
# The rest of the object is derived deterministically from the parameters
# and does not need saving.
list(mvnorm=private$.mvnorm)
},

#' @description Restore the correlation state.
#' Only the randomly drawn weights are restored. The object needs to be
#' initialized with the same rhos.
#' @param state a previously saved correlation state, as returned by the
#' save_state method.
restore_state = function(state) {
private$.mvnorm <- state$mvnorm
}
)
)
Expand Down
8 changes: 8 additions & 0 deletions R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ LaggedValue <- R6::R6Class(

get = function(timestep) {
timeseries_at(private$history, timestep, TRUE)
},

save_state = function() {
timeseries_save_state(private$history)
},

restore_state = function(state) {
timeseries_restore_state(private$history, state)
}
)
)
64 changes: 59 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ run_simulation <- function(
timesteps,
parameters = NULL,
correlations = NULL
) {
run_resumable_simulation(timesteps, parameters, correlations)$data
}

#' @title Run the simulation in a resumable way
#'
#' @description this function accepts an initial simulation state as an argument, and returns the
#' final state after running all of its timesteps. This allows one run to be resumed, possibly
#' having changed some of the parameters.
#' @param timesteps the timestep at which to stop the simulation
#' @param parameters a named list of parameters to use
#' @param correlations correlation parameters
#' @param initial_state the state from which the simulation is resumed
#' @param restore_random_state if TRUE, restore the random number generator's state from the checkpoint.
#' @return a list with two entries, one for the dataframe of results and one for the final
#' simulation state.
run_resumable_simulation <- function(
timesteps,
parameters = NULL,
correlations = NULL,
initial_state = NULL,
restore_random_state = FALSE
) {
random_seed(ceiling(runif(1) * .Machine$integer.max))
if (is.null(parameters)) {
Expand All @@ -108,7 +130,23 @@ run_simulation <- function(
)
vector_models <- parameterise_mosquito_models(parameters, timesteps)
solvers <- parameterise_solvers(vector_models, parameters)
individual::simulation_loop(

lagged_eir <- create_lagged_eir(variables, solvers, parameters)
lagged_infectivity <- create_lagged_infectivity(variables, parameters)

stateful_objects <- unlist(list(
RandomState$new(restore_random_state),
correlations,
vector_models,
solvers,
lagged_eir,
lagged_infectivity))

if (!is.null(initial_state)) {
restore_state(initial_state$malariasimulation, stateful_objects)
}

individual_state <- individual::simulation_loop(
processes = create_processes(
renderer,
variables,
Expand All @@ -117,15 +155,31 @@ run_simulation <- function(
vector_models,
solvers,
correlations,
list(create_lagged_eir(variables, solvers, parameters)),
list(create_lagged_infectivity(variables, parameters)),
list(lagged_eir),
list(lagged_infectivity),
timesteps
),
variables = variables,
events = unlist(events),
timesteps = timesteps
timesteps = timesteps,
state = initial_state$individual,
restore_random_state = restore_random_state
)
renderer$to_dataframe()

final_state <- list(
timesteps=timesteps,
individual=individual_state,
malariasimulation=save_state(stateful_objects)
)

data <- renderer$to_dataframe()
if (!is.null(initial_state)) {
# Drop the timesteps we didn't simulate from the data.
# It would just be full of NA.
data <- data[-(1:initial_state$timesteps),]
}

list(data=data, state=final_state)
}

#' @title Run a metapopulation model
Expand Down
2 changes: 1 addition & 1 deletion R/mosquito_biology.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ create_mosquito_emergence_process <- function(
p_counts <- vnapply(
solvers,
function(solver) {
solver_get_states(solver)[[ODE_INDICES[['P']]]]
solver$get_states()[[ODE_INDICES[['P']]]]
}
)
n <- sum(p_counts) * rate
Expand Down
2 changes: 1 addition & 1 deletion R/render.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ create_total_M_renderer_compartmental <- function(renderer, solvers, parameters)
function(timestep) {
total_M <- 0
for (i in seq_along(solvers)) {
row <- solver_get_states(solvers[[i]])
row <- solvers[[i]]$get_states()
species_M <- sum(row[ADULT_ODE_INDICES])
total_M <- total_M + species_M
renderer$render(paste0('total_M_', parameters$species[[i]]), species_M, timestep)
Expand Down
Loading

0 comments on commit c42e2c5

Please sign in to comment.