Skip to content

Commit

Permalink
Merge branch 'dev' into feat/spc
Browse files Browse the repository at this point in the history
  • Loading branch information
tbreweric authored Apr 8, 2024
2 parents 5123d6e + 2a3d4cc commit 54456df
Show file tree
Hide file tree
Showing 32 changed files with 847 additions and 150 deletions.
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ LazyData: true
Remotes:
mrc-ide/malariaEquilibrium,
mrc-ide/individual
Additional_repositories:
https://mrc-ide.r-universe.dev
Imports:
individual (>= 0.1.7),
individual (>= 0.1.15),
malariaEquilibrium (>= 1.0.1),
Rcpp,
statmod,
MASS,
dqrng,
dqrng (>= 0.3.2.2),
sitmo,
BH,
R6,
Expand Down
32 changes: 32 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ adult_mosquito_model_update <- function(model, mu, foim, susceptible, f) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_update`, model, mu, foim, susceptible, f))
}

adult_mosquito_model_save_state <- function(model) {
.Call(`_malariasimulation_adult_mosquito_model_save_state`, model)
}

adult_mosquito_model_restore_state <- function(model, state) {
invisible(.Call(`_malariasimulation_adult_mosquito_model_restore_state`, model, state))
}

create_adult_solver <- function(model, init, r_tol, a_tol, max_steps) {
.Call(`_malariasimulation_create_adult_solver`, model, init, r_tol, a_tol, max_steps)
}
Expand Down Expand Up @@ -37,10 +45,18 @@ rainfall <- function(t, g0, g, h, floor) {
.Call(`_malariasimulation_rainfall`, t, g0, g, h, floor)
}

exponential_process_cpp <- function(variable, rate) {
.Call(`_malariasimulation_exponential_process_cpp`, variable, rate)
}

solver_get_states <- function(solver) {
.Call(`_malariasimulation_solver_get_states`, solver)
}

solver_set_states <- function(solver, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, state))
}

solver_step <- function(solver) {
invisible(.Call(`_malariasimulation_solver_step`, solver))
}
Expand All @@ -57,10 +73,26 @@ timeseries_push <- function(timeseries, value, timestep) {
invisible(.Call(`_malariasimulation_timeseries_push`, timeseries, value, timestep))
}

timeseries_save_state <- function(timeseries) {
.Call(`_malariasimulation_timeseries_save_state`, timeseries)
}

timeseries_restore_state <- function(timeseries, state) {
invisible(.Call(`_malariasimulation_timeseries_restore_state`, timeseries, state))
}

random_seed <- function(seed) {
invisible(.Call(`_malariasimulation_random_seed`, seed))
}

random_save_state <- function() {
.Call(`_malariasimulation_random_save_state`)
}

random_restore_state <- function(state) {
invisible(.Call(`_malariasimulation_random_restore_state`, state))
}

bernoulli_multi_p_cpp <- function(p) {
.Call(`_malariasimulation_bernoulli_multi_p_cpp`, p)
}
Expand Down
4 changes: 2 additions & 2 deletions R/antimalarial_resistance.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ set_antimalarial_resistance <- function(parameters,
reinfection_during_prophylaxis_probability < 0 | reinfection_during_prophylaxis_probability > 1)) {
stop("Resistance outcome probabilities must fall between 0 and 1")
}

if(length(slow_parasite_clearance_time) != 1) {
if(length(slow_parasite_clearance_time) != 1) {
stop("Error: length of slow_parasite_clearance_time not equal to 1")
}

Expand Down
8 changes: 4 additions & 4 deletions R/biting_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ simulate_bites <- function(

for (s_i in seq_along(parameters$species)) {
species_name <- parameters$species[[s_i]]
solver_states <- solver_get_states(solvers[[s_i]])
solver_states <- solvers[[s_i]]$get_states()
p_bitten <- prob_bitten(timestep, variables, s_i, parameters)
Q0 <- parameters$Q0[[s_i]]
W <- average_p_successful(p_bitten$prob_bitten_survives, .pi, Q0)
Expand Down Expand Up @@ -167,7 +167,7 @@ simulate_bites <- function(
if (parameters$individual_mosquitoes) {
# update the ODE with stats for ovoposition calculations
aquatic_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
species_index$size(),
f,
mu
Expand All @@ -189,7 +189,7 @@ simulate_bites <- function(
)
} else {
adult_mosquito_model_update(
models[[s_i]],
models[[s_i]]$.model,
mu,
foim,
solver_states[[ADULT_ODE_INDICES['Sm']]],
Expand Down Expand Up @@ -235,7 +235,7 @@ calculate_infectious <- function(species, solvers, variables, parameters) {
)
)
}
calculate_infectious_compartmental(solver_get_states(solvers[[species]]))
calculate_infectious_compartmental(solvers[[species]]$get_states())
}

calculate_infectious_individual <- function(
Expand Down
81 changes: 70 additions & 11 deletions R/compartmental.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ parameterise_mosquito_models <- function(parameters, timesteps) {
m
)[ADULT_ODE_INDICES['Sm']]
return(
create_adult_mosquito_model(
AdultMosquitoModel$new(create_adult_mosquito_model(
growth_model,
parameters$mum[[i]],
parameters$dem,
susceptible * parameters$init_foim,
parameters$init_foim
)
))
)
}
growth_model
AquaticMosquitoModel$new(growth_model)
}
)
}
Expand All @@ -72,22 +72,22 @@ parameterise_solvers <- function(models, parameters) {
init <- initial_mosquito_counts(parameters, i, parameters$init_foim, m)
if (!parameters$individual_mosquitoes) {
return(
create_adult_solver(
models[[i]],
Solver$new(create_adult_solver(
models[[i]]$.model,
init,
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
)
}
create_aquatic_solver(
models[[i]],
Solver$new(create_aquatic_solver(
models[[i]]$.model,
init[ODE_INDICES],
parameters$r_tol,
parameters$a_tol,
parameters$ode_max_steps
)
))
}
)
}
Expand All @@ -103,7 +103,7 @@ create_compartmental_rendering_process <- function(renderer, solvers, parameters
counts <- rep(0, length(indices))
for (s_i in seq_along(solvers)) {
if (parameters$species_proportions[[s_i]] > 0) {
row <- solver_get_states(solvers[[s_i]])
row <- solvers[[s_i]]$get_states()
} else {
row <- rep(0, length(indices))
}
Expand All @@ -128,8 +128,67 @@ create_solver_stepping_process <- function(solvers, parameters) {
function(timestep) {
for (i in seq_along(solvers)) {
if (parameters$species_proportions[[i]] > 0) {
solver_step(solvers[[i]])
solvers[[i]]$step()
}
}
}
}

Solver <- R6::R6Class(
'Solver',
private = list(
.solver = NULL
),
public = list(
initialize = function(solver) {
private$.solver <- solver
},
step = function() {
solver_step(private$.solver)
},
get_states = function() {
solver_get_states(private$.solver)
},

# This is the same as `get_states`, just exposed under the interface that
# is expected of stateful objects.
save_state = function() {
solver_get_states(private$.solver)
},
restore_state = function(state) {
solver_set_states(private$.solver, state)
}
)
)

AquaticMosquitoModel <- R6::R6Class(
'AquaticMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},

# The aquatic mosquito model doesn't have any state to save or restore (the
# state of the ODE is stored separately). We still provide these methods to
# conform to the expected interface.
save_state = function() { NULL },
restore_state = function(state) { }
)
)

AdultMosquitoModel <- R6::R6Class(
'AdultMosquitoModel',
public = list(
.model = NULL,
initialize = function(model) {
self$.model <- model
},
save_state = function() {
adult_mosquito_model_save_state(self$.model)
},
restore_state = function(state) {
adult_mosquito_model_restore_state(self$.model, state)
}
)
)
37 changes: 28 additions & 9 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ CorrelationParameters <- R6::R6Class(
public = list(

#' @description initialise correlation parameters
#' @param parameters model parameters
initialize = function(parameters) {
# Find a list of enabled interventions
enabled <- vlapply(INTS, function(name) parameters[[name]])
private$interventions <- INTS[enabled]
#' @param population popularion size
#' @param interventions character vector with the name of enabled interventions
initialize = function(population, interventions) {
private$population <- population
private$interventions <- interventions

# Initialise a rho matrix for our interventions
n_ints <- private$n_ints()
Expand All @@ -38,9 +38,6 @@ CorrelationParameters <- R6::R6Class(
ncol = n_ints,
dimnames = list(private$interventions, private$interventions)
)

# Store population for mvnorm draws
private$population <- parameters$human_population
},

#' @description Add rho between rounds
Expand Down Expand Up @@ -115,6 +112,25 @@ CorrelationParameters <- R6::R6Class(
dimnames(private$.mvnorm)[[2]] <- private$interventions
}
private$.mvnorm
},

#' @description Save the correlation state.
save_state = function() {
# mvnorm is sampled at random lazily on its first use. We need to save it
# in order to restore the same value when resuming the simulation,
# otherwise we would be drawing a new, probably different, value.
# The rest of the object is derived deterministically from the parameters
# and does not need saving.
list(mvnorm=private$.mvnorm)
},

#' @description Restore the correlation state.
#' Only the randomly drawn weights are restored. The object needs to be
#' initialized with the same rhos.
#' @param state a previously saved correlation state, as returned by the
#' save_state method.
restore_state = function(state) {
private$.mvnorm <- state$mvnorm
}
)
)
Expand Down Expand Up @@ -164,7 +180,10 @@ CorrelationParameters <- R6::R6Class(
#'
#' # You can now pass the correlation parameters to the run_simulation function
get_correlation_parameters <- function(parameters) {
CorrelationParameters$new(parameters)
# Find a list of enabled interventions
enabled <- vlapply(INTS, function(name) parameters[[name]])

CorrelationParameters$new(parameters$human_population, INTS[enabled])
}

#' @title Sample a population to intervene in given the correlation parameters
Expand Down
8 changes: 8 additions & 0 deletions R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ LaggedValue <- R6::R6Class(

get = function(timestep) {
timeseries_at(private$history, timestep, TRUE)
},

save_state = function() {
timeseries_save_state(private$history)
},

restore_state = function(state) {
timeseries_restore_state(private$history, state)
}
)
)
Loading

0 comments on commit 54456df

Please sign in to comment.