Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial support for saving and restoring simulation state. #280

Merged
merged 12 commits into from
Feb 27, 2024
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ LazyData: true
Remotes:
mrc-ide/malariaEquilibrium,
mrc-ide/individual
Additional_repositories:
https://mrc-ide.r-universe.dev
Imports:
individual (>= 0.1.7),
individual (>= 0.1.13),
malariaEquilibrium (>= 1.0.1),
Rcpp,
statmod,
MASS,
dqrng,
dqrng (>= 0.3.2.2),
sitmo,
BH,
R6,
Expand Down
28 changes: 28 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ adult_mosquito_model_update <- function(model, mu, foim, susceptible, f) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_update`, model, mu, foim, susceptible, f))
}

adult_mosquito_model_save_state <- function(model) {
.Call(`_malariasimulation_adult_mosquito_model_save_state`, model)
}

adult_mosquito_model_restore_state <- function(model, state) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_restore_state`, model, state))
}

create_adult_solver <- function(model, init, r_tol, a_tol, max_steps) {
.Call(`_malariasimulation_create_adult_solver`, model, init, r_tol, a_tol, max_steps)
}
Expand Down Expand Up @@ -41,6 +49,10 @@ solver_get_states <- function(solver) {
.Call(`_malariasimulation_solver_get_states`, solver)
}

solver_set_states <- function(solver, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, state))
}

solver_step <- function(solver) {
invisible(.Call(`_malariasimulation_solver_step`, solver))
}
Expand All @@ -57,10 +69,26 @@ timeseries_push <- function(timeseries, value, timestep) {
invisible(.Call(`_malariasimulation_timeseries_push`, timeseries, value, timestep))
}

timeseries_save_state <- function(timeseries) {
.Call(`_malariasimulation_timeseries_save_state`, timeseries)
}

timeseries_restore_state <- function(timeseries, state) {
invisible(.Call(`_malariasimulation_timeseries_restore_state`, timeseries, state))
}

random_seed <- function(seed) {
invisible(.Call(`_malariasimulation_random_seed`, seed))
}

random_save_state <- function() {
.Call(`_malariasimulation_random_save_state`)
}

random_restore_state <- function(state) {
invisible(.Call(`_malariasimulation_random_restore_state`, state))
}

bernoulli_multi_p_cpp <- function(p) {
.Call(`_malariasimulation_bernoulli_multi_p_cpp`, p)
}
Expand Down
8 changes: 4 additions & 4 deletions R/biting_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ simulate_bites <- function(

for (s_i in seq_along(parameters$species)) {
species_name <- parameters$species[[s_i]]
solver_states <- solver_get_states(solvers[[s_i]])
solver_states <- solvers[[s_i]]$get_states()
p_bitten <- prob_bitten(timestep, variables, s_i, parameters)
Q0 <- parameters$Q0[[s_i]]
W <- average_p_successful(p_bitten$prob_bitten_survives, .pi, Q0)
Expand Down Expand Up @@ -167,7 +167,7 @@ simulate_bites <- function(
if (parameters$individual_mosquitoes) {
# update the ODE with stats for ovoposition calculations
aquatic_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
species_index$size(),
f,
mu
Expand All @@ -189,7 +189,7 @@ simulate_bites <- function(
)
} else {
adult_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
mu,
foim,
solver_states[[ADULT_ODE_INDICES['Sm']]],
Expand Down Expand Up @@ -235,7 +235,7 @@ calculate_infectious <- function(species, solvers, variables, parameters) {
)
)
}
calculate_infectious_compartmental(solver_get_states(solvers[[species]]))
calculate_infectious_compartmental(solvers[[species]]$get_states())
}

calculate_infectious_individual <- function(
Expand Down
81 changes: 70 additions & 11 deletions R/compartmental.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ parameterise_mosquito_models <- function(parameters, timesteps) {
m
)[ADULT_ODE_INDICES['Sm']]
return(
create_adult_mosquito_model(
AdultMosquitoModel$new(create_adult_mosquito_model(
growth_model,
parameters$mum[[i]],
parameters$dem,
susceptible * parameters$init_foim,
parameters$init_foim
)
))
)
}
growth_model
AquaticMosquitoModel$new(growth_model)
}
)
}
Expand All @@ -72,22 +72,22 @@ parameterise_solvers <- function(models, parameters) {
init <- initial_mosquito_counts(parameters, i, parameters$init_foim, m)
if (!parameters$individual_mosquitoes) {
return(
create_adult_solver(
models[[i]],
Solver$new(create_adult_solver(
models[[i]]$.model,
init,
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
)
}
create_aquatic_solver(
models[[i]],
Solver$new(create_aquatic_solver(
models[[i]]$.model,
init[ODE_INDICES],
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
}
)
}
Expand All @@ -103,7 +103,7 @@ create_compartmental_rendering_process <- function(renderer, solvers, parameters
counts <- rep(0, length(indices))
for (s_i in seq_along(solvers)) {
if (parameters$species_proportions[[s_i]] > 0) {
row <- solver_get_states(solvers[[s_i]])
row <- solvers[[s_i]]$get_states()
} else {
row <- rep(0, length(indices))
}
Expand All @@ -128,8 +128,67 @@ create_solver_stepping_process <- function(solvers, parameters) {
function(timestep) {
for (i in seq_along(solvers)) {
if (parameters$species_proportions[[i]] > 0) {
solver_step(solvers[[i]])
solvers[[i]]$step()
}
}
}
}

Solver <- R6::R6Class(
'Solver',
private = list(
.solver = NULL
),
public = list(
initialize = function(solver) {
private$.solver <- solver
},
step = function() {
solver_step(private$.solver)
},
get_states = function() {
solver_get_states(private$.solver)
},

# This is the same as `get_states`, just exposed under the interface that
# is expected of stateful objects.
save_state = function() {
solver_get_states(private$.solver)
},
restore_state = function(state) {
solver_set_states(private$.solver, state)
}
)
)

AquaticMosquitoModel <- R6::R6Class(
'AquaticMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},

# The aquatic mosquito model doesn't have any state to save or restore (the
# state of the ODE is stored separately). We still provide these methods to
# conform to the expected interface.
save_state = function() { NULL },
restore_state = function(state) { }
)
)

AdultMosquitoModel <- R6::R6Class(
'AdultMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},
save_state = function() {
adult_mosquito_model_save_state(self$.model)
},
restore_state = function(state) {
adult_mosquito_model_restore_state(self$.model, state)
}
)
)
19 changes: 19 additions & 0 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,25 @@ CorrelationParameters <- R6::R6Class(
dimnames(private$.mvnorm)[[2]] <- private$interventions
}
private$.mvnorm
},

#' @description Save the correlation state.
save_state = function() {
# mvnorm is sampled at random lazily on its first use. We need to save it
# in order to restore the same value when resuming the simulation,
# otherwise we would be drawing a new, probably different, value.
# The rest of the object is derived deterministically from the parameters
# and does not need saving.
list(mvnorm=private$.mvnorm)
},

#' @description Restore the correlation state.
#' Only the randomly drawn weights are restored. The object needs to be
#' initialized with the same rhos.
#' @param state a previously saved correlation state, as returned by the
#' save_state method.
restore_state = function(state) {
private$.mvnorm <- state$mvnorm
}
)
)
Expand Down
8 changes: 8 additions & 0 deletions R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ LaggedValue <- R6::R6Class(

get = function(timestep) {
timeseries_at(private$history, timestep, TRUE)
},

save_state = function() {
timeseries_save_state(private$history)
},

restore_state = function(state) {
timeseries_restore_state(private$history, state)
}
)
)
64 changes: 59 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ run_simulation <- function(
timesteps,
parameters = NULL,
correlations = NULL
) {
run_resumable_simulation(timesteps, parameters, correlations)$data
}

#' @title Run the simulation in a resumable way
#'
#' @description this function accepts an initial simulation state as an argument, and returns the
#' final state after running all of its timesteps. This allows one run to be resumed, possibly
#' having changed some of the parameters.
#' @param timesteps the timestep at which to stop the simulation
#' @param parameters a named list of parameters to use
#' @param correlations correlation parameters
#' @param initial_state the state from which the simulation is resumed
#' @param restore_random_state if TRUE, restore the random number generator's state from the checkpoint.
#' @return a list with two entries, one for the dataframe of results and one for the final
#' simulation state.
run_resumable_simulation <- function(
timesteps,
parameters = NULL,
correlations = NULL,
initial_state = NULL,
restore_random_state = FALSE
) {
random_seed(ceiling(runif(1) * .Machine$integer.max))
if (is.null(parameters)) {
Expand All @@ -108,7 +130,23 @@ run_simulation <- function(
)
vector_models <- parameterise_mosquito_models(parameters, timesteps)
solvers <- parameterise_solvers(vector_models, parameters)
individual::simulation_loop(

lagged_eir <- create_lagged_eir(variables, solvers, parameters)
lagged_infectivity <- create_lagged_infectivity(variables, parameters)

stateful_objects <- unlist(list(
RandomState$new(restore_random_state),
correlations,
vector_models,
solvers,
lagged_eir,
lagged_infectivity))

if (!is.null(initial_state)) {
restore_state(initial_state$malariasimulation, stateful_objects)
}

individual_state <- individual::simulation_loop(
processes = create_processes(
renderer,
variables,
Expand All @@ -117,15 +155,31 @@ run_simulation <- function(
vector_models,
solvers,
correlations,
list(create_lagged_eir(variables, solvers, parameters)),
list(create_lagged_infectivity(variables, parameters)),
list(lagged_eir),
list(lagged_infectivity),
timesteps
),
variables = variables,
events = unlist(events),
timesteps = timesteps
timesteps = timesteps,
state = initial_state$individual,
restore_random_state = restore_random_state
)
renderer$to_dataframe()

final_state <- list(
timesteps=timesteps,
individual=individual_state,
malariasimulation=save_state(stateful_objects)
)

data <- renderer$to_dataframe()
if (!is.null(initial_state)) {
# Drop the timesteps we didn't simulate from the data.
# It would just be full of NA.
data <- data[-(1:initial_state$timesteps),]
}

list(data=data, state=final_state)
}

#' @title Run a metapopulation model
Expand Down
2 changes: 1 addition & 1 deletion R/mosquito_biology.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ create_mosquito_emergence_process <- function(
p_counts <- vnapply(
solvers,
function(solver) {
solver_get_states(solver)[[ODE_INDICES[['P']]]]
solver$get_states()[[ODE_INDICES[['P']]]]
}
)
n <- sum(p_counts) * rate
Expand Down
2 changes: 1 addition & 1 deletion R/render.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ create_total_M_renderer_compartmental <- function(renderer, solvers, parameters)
function(timestep) {
total_M <- 0
for (i in seq_along(solvers)) {
row <- solver_get_states(solvers[[i]])
row <- solvers[[i]]$get_states()
species_M <- sum(row[ADULT_ODE_INDICES])
total_M <- total_M + species_M
renderer$render(paste0('total_M_', parameters$species[[i]]), species_M, timestep)
Expand Down
Loading
Loading