From d16978e59261eca17193c6518101150ad95a1ee6 Mon Sep 17 00:00:00 2001 From: Thomas Brewer Date: Thu, 15 Feb 2024 11:52:51 +0000 Subject: [PATCH 1/4] Added check for undeveloped antimalarial resistance features (SPC/LPF/LCF/RDP). Included test to ensure check works and amended documentation in get_parameters() to warn users features in development. --- R/antimalarial_resistance.R | 12 +- R/parameters.R | 12 +- man/get_parameters.Rd | 12 +- tests/testthat/test-antimalarial-resistance.R | 243 ++++++++++++------ 4 files changed, 186 insertions(+), 93 deletions(-) diff --git a/R/antimalarial_resistance.R b/R/antimalarial_resistance.R index 532c042b..18d14d9a 100644 --- a/R/antimalarial_resistance.R +++ b/R/antimalarial_resistance.R @@ -26,6 +26,14 @@ set_antimalarial_resistance <- function(parameters, reinfection_prob, slow_parasite_clearance_time) { + if(any(partner_drug_resistance > 0, + slow_parasite_clearance_prob > 0, + late_clinical_failure_prob > 0, + late_parasitological_prob > 0, + reinfection_prob > 0)) { + stop("Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + } + if(any(c(length(artemisinin_resistance), length(partner_drug_resistance), length(slow_parasite_clearance_prob), @@ -48,8 +56,8 @@ set_antimalarial_resistance <- function(parameters, reinfection_prob < 0 | reinfection_prob > 1)) { stop("Resistance outcome probabilities must fall between 0 and 1") } - - if(length(slow_parasite_clearance_time) != 1) { + + if(length(slow_parasite_clearance_time) != 1) { stop("Error: length of slow_parasite_clearance_time not equal to 1") } diff --git a/R/parameters.R b/R/parameters.R index 9abbeb60..473225fa 100644 --- a/R/parameters.R +++ b/R/parameters.R @@ -201,13 +201,13 @@ #' * antimalarial_resistance_drug - vector of drugs for which resistance can be parameterised; default = NULL #' * antimalarial_resistance_timesteps - vector of time steps on which resistance updates occur; default = NULL #' * prop_artemisinin_resistant - vector of proportions of infections resistant to the artemisinin component of a given drug; default = NULL -#' * prop_partner_drug_resistant - vector of proportions of infections resistant to the parter drug component of a given drug; default = NULL -#' * slow_parasite_clearance_prob - vector of probabilities of slow parasite clearance for a given drug; default = NULL +#' * prop_partner_drug_resistant - vector of proportions of infections resistant to the parter drug component of a given drug; default = NULL (currently unimplemented) +#' * slow_parasite_clearance_prob - vector of probabilities of slow parasite clearance for a given drug; default = NULL (currently unimplemented) #' * early_treatment_failure_prob - vector of probabilities of early treatment failure for a given drug; default = NULL -#' * late_clinical_failure_prob - vector of probabilities of late clinical failure for a given drug; default = NULL -#' * late_parasitological_failure_prob - vector of probabilities of late parasitological failure for a given drug; default = NULL -#' * reinfection_during_prophylaxis - vector of probabilities of reinfection during prophylaxis for a given drug; default = NULL -#' * dt_slow_parasite_clearance - the delay for humans experiencing slow parasite clearance to move from state Tr to S; default = NULL +#' * late_clinical_failure_prob - vector of probabilities of late clinical failure for a given drug; default = NULL (currently unimplemented) +#' * late_parasitological_failure_prob - vector of probabilities of late parasitological failure for a given drug; default = NULL (currently unimplemented) +#' * reinfection_during_prophylaxis - vector of probabilities of reinfection during prophylaxis for a given drug; default = NULL (currently unimplemented) +#' * dt_slow_parasite_clearance - the delay for humans experiencing slow parasite clearance to move from state Tr to S; default = NULL (currently unimplemented) #' #' rendering: #' All values are in timesteps and all ranges are inclusive diff --git a/man/get_parameters.Rd b/man/get_parameters.Rd index fc6ba000..1a0dde29 100644 --- a/man/get_parameters.Rd +++ b/man/get_parameters.Rd @@ -222,13 +222,13 @@ please set antimalarial resistance parameters with the convenience functions in \item antimalarial_resistance_drug - vector of drugs for which resistance can be parameterised; default = NULL \item antimalarial_resistance_timesteps - vector of time steps on which resistance updates occur; default = NULL \item prop_artemisinin_resistant - vector of proportions of infections resistant to the artemisinin component of a given drug; default = NULL -\item prop_partner_drug_resistant - vector of proportions of infections resistant to the parter drug component of a given drug; default = NULL -\item slow_parasite_clearance_prob - vector of probabilities of slow parasite clearance for a given drug; default = NULL +\item prop_partner_drug_resistant - vector of proportions of infections resistant to the parter drug component of a given drug; default = NULL (currently unimplemented) +\item slow_parasite_clearance_prob - vector of probabilities of slow parasite clearance for a given drug; default = NULL (currently unimplemented) \item early_treatment_failure_prob - vector of probabilities of early treatment failure for a given drug; default = NULL -\item late_clinical_failure_prob - vector of probabilities of late clinical failure for a given drug; default = NULL -\item late_parasitological_failure_prob - vector of probabilities of late parasitological failure for a given drug; default = NULL -\item reinfection_during_prophylaxis - vector of probabilities of reinfection during prophylaxis for a given drug; default = NULL -\item dt_slow_parasite_clearance - the delay for humans experiencing slow parasite clearance to move from state Tr to S; default = NULL +\item late_clinical_failure_prob - vector of probabilities of late clinical failure for a given drug; default = NULL (currently unimplemented) +\item late_parasitological_failure_prob - vector of probabilities of late parasitological failure for a given drug; default = NULL (currently unimplemented) +\item reinfection_during_prophylaxis - vector of probabilities of reinfection during prophylaxis for a given drug; default = NULL (currently unimplemented) +\item dt_slow_parasite_clearance - the delay for humans experiencing slow parasite clearance to move from state Tr to S; default = NULL (currently unimplemented) } rendering: diff --git a/tests/testthat/test-antimalarial-resistance.R b/tests/testthat/test-antimalarial-resistance.R index 400d93a3..feb7f67a 100644 --- a/tests/testthat/test-antimalarial-resistance.R +++ b/tests/testthat/test-antimalarial-resistance.R @@ -10,11 +10,11 @@ test_that('set_antimalarial_resistance() toggles resistance on', { timesteps = 1, artemisinin_resistance = 0.5, partner_drug_resistance = 0, - slow_parasite_clearance_prob = 0.5, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.6, - late_clinical_failure_prob = 0.2, - late_parasitological_prob = 0.3, - reinfection_prob = 0.4, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 10) -> simparams expect_identical(object = simparams$antimalarial_resistance, expected = TRUE) }) @@ -31,11 +31,11 @@ test_that('set_antimalarial_resistance() errors if parameter inputs of different timesteps = c(1, 10), artemisinin_resistance = 0.5, partner_drug_resistance = 0, - slow_parasite_clearance_prob = 0.5, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.6, - late_clinical_failure_prob = 0.2, - late_parasitological_prob = 0.3, - reinfection_prob = 0.4, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 10)) }) @@ -51,11 +51,11 @@ test_that('set_antimalarial_resistance() errors if resistance proportions outsid timesteps = 1, artemisinin_resistance = 1.01, partner_drug_resistance = 0, - slow_parasite_clearance_prob = 0.5, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.6, - late_clinical_failure_prob = 0.2, - late_parasitological_prob = 0.3, - reinfection_prob = 0.4, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 10), regexp = "Artemisinin and partner-drug resistance proportions must fall between 0 and 1") }) @@ -72,11 +72,11 @@ test_that('set_antimalarial_resistance() errors if resistance phenotype probabil timesteps = 1, artemisinin_resistance = 0.4, partner_drug_resistance = 0, - slow_parasite_clearance_prob = -0.5, - early_treatment_failure_prob = 0.6, - late_clinical_failure_prob = 0.2, - late_parasitological_prob = 0.3, - reinfection_prob = 0.4, + slow_parasite_clearance_prob = 0, + early_treatment_failure_prob = -0.6, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 5)) }) @@ -91,12 +91,12 @@ test_that('set_antimalarial_resistance() errors if drug index > than number of d drug = 2, timesteps = 1, artemisinin_resistance = 0.4, - partner_drug_resistance = 0.3, - slow_parasite_clearance_prob = 0.5, + partner_drug_resistance = 0, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.6, - late_clinical_failure_prob = 0.2, - late_parasitological_prob = 0.3, - reinfection_prob = 0.4)) + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0)) }) test_that('set_antimalarial_resistance() assigns parameters correctly despite order in which resistance parameters are specified', { @@ -111,7 +111,7 @@ test_that('set_antimalarial_resistance() assigns parameters correctly despite or timesteps = 1, artemisinin_resistance = 0.5, partner_drug_resistance = 0, - slow_parasite_clearance_prob = 0.41, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.2, late_clinical_failure_prob = 0, late_parasitological_prob = 0, @@ -120,36 +120,36 @@ test_that('set_antimalarial_resistance() assigns parameters correctly despite or parameters <- set_antimalarial_resistance(parameters = parameters, drug = 3, timesteps = 1, - artemisinin_resistance = 0, - partner_drug_resistance = 0.43, + artemisinin_resistance = 0.43, + partner_drug_resistance = 0, slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0, - late_clinical_failure_prob = 0.01, - late_parasitological_prob = 0.42, - reinfection_prob = 0.89, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 10) parameters <- set_antimalarial_resistance(parameters = parameters, drug = 1, timesteps = 1, artemisinin_resistance = 0.27, - partner_drug_resistance = 0.61, - slow_parasite_clearance_prob = 0.23, + partner_drug_resistance = 0, + slow_parasite_clearance_prob = 0, early_treatment_failure_prob = 0.9, - late_clinical_failure_prob = 0.49, - late_parasitological_prob = 0.81, - reinfection_prob = 0.009, + late_clinical_failure_prob = 0, + late_parasitological_prob = 0, + reinfection_prob = 0, slow_parasite_clearance_time = 20) expect_identical(parameters$antimalarial_resistance, TRUE) expect_identical(unlist(parameters$antimalarial_resistance_drug), c(2, 3, 1)) expect_identical(unlist(parameters$antimalarial_resistance_timesteps), rep(1, 3)) - expect_identical(unlist(parameters$prop_artemisinin_resistant), c(0.5, 0, 0.27)) - expect_identical(unlist(parameters$prop_partner_drug_resistant), c(0, 0.43, 0.61)) - expect_identical(unlist(parameters$slow_parasite_clearance_prob), c(0.41, 0, 0.23)) + expect_identical(unlist(parameters$prop_artemisinin_resistant), c(0.5, 0.43, 0.27)) + expect_identical(unlist(parameters$prop_partner_drug_resistant), c(0, 0, 0)) + expect_identical(unlist(parameters$slow_parasite_clearance_prob), c(0, 0, 0)) expect_identical(unlist(parameters$early_treatment_failure_prob), c(0.2, 0, 0.9)) - expect_identical(unlist(parameters$late_clinical_failure_prob), c(0, 0.01, 0.49)) - expect_identical(unlist(parameters$late_parasitological_failure_prob), c(0, 0.42, 0.81)) - expect_identical(unlist(parameters$reinfection_during_prophylaxis), c(0, 0.89, 0.009)) + expect_identical(unlist(parameters$late_clinical_failure_prob), c(0, 0, 0)) + expect_identical(unlist(parameters$late_parasitological_failure_prob), c(0, 0, 0)) + expect_identical(unlist(parameters$reinfection_during_prophylaxis), c(0, 0, 0)) expect_identical(unlist(parameters$dt_slow_parasite_clearance), c(5, 10, 20)) }) @@ -168,12 +168,12 @@ test_that(desc = "set_antimalarial_resistance errors if length slow_parasite_cle drug = 1, timesteps = c(0, 10), artemisinin_resistance = c(0.4, 0.8), - partner_drug_resistance = c(0.23, 0.43), - slow_parasite_clearance_prob = c(0.2, 0.4), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0, 0.45), - late_clinical_failure_prob = c(0.01, 0.01), - late_parasitological_prob = c(0.05, 0.06), - reinfection_prob = c(0.86, 0.86), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = c(10 ,11)), "Error: length of slow_parasite_clearance_time not equal to 1") }) @@ -192,16 +192,101 @@ test_that(desc = "set_antimalarial_resistance errors if slow_parasite_clearance_ drug = 1, timesteps = c(0, 10), artemisinin_resistance = c(0.4, 0.8), - partner_drug_resistance = c(0.23, 0.43), - slow_parasite_clearance_prob = c(0.2, 0.4), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0, 0.45), - late_clinical_failure_prob = c(0.01, 0.01), - late_parasitological_prob = c(0.05, 0.06), - reinfection_prob = c(0.86, 0.86), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = c(0)), "Error: slow_parasite_clearance_time is non-positive") }) +test_that("set_antimalarial_resistance() errors when users attempt to use undeveloped model features", { + + # Partner Drug Resistance + expect_error(get_parameters() |> + set_drugs(drugs = list(SP_AQ_params)) |> + set_clinical_treatment(drug = 1, timesteps = 1, coverages = 0.6) |> + set_antimalarial_resistance(drug = 1, + timesteps = c(1, 10), + artemisinin_resistance = c(0.4, 0.5), + partner_drug_resistance = c(0, 0.5), + slow_parasite_clearance_prob = c(0, 0), + early_treatment_failure_prob = c(0.8, 0.8), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), + slow_parasite_clearance_time = 10), + "Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + + + # Slow Parasite Clearance + expect_error(get_parameters() |> + set_drugs(drugs = list(SP_AQ_params)) |> + set_clinical_treatment(drug = 1, timesteps = 1, coverages = 0.6) |> + set_antimalarial_resistance(drug = 1, + timesteps = c(1, 10), + artemisinin_resistance = c(0.4, 0.5), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0.0001), + early_treatment_failure_prob = c(0.8, 0.8), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), + slow_parasite_clearance_time = 10), + "Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + + # Late Clinical Failure + expect_error(get_parameters() |> + set_drugs(drugs = list(SP_AQ_params)) |> + set_clinical_treatment(drug = 1, timesteps = 1, coverages = 0.6) |> + set_antimalarial_resistance(drug = 1, + timesteps = c(1, 10), + artemisinin_resistance = c(0.4, 0.5), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), + early_treatment_failure_prob = c(0.8, 0.8), + late_clinical_failure_prob = c(0.6, 0.43), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), + slow_parasite_clearance_time = 10), + "Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + + # Late Parasitological Failure + expect_error(get_parameters() |> + set_drugs(drugs = list(SP_AQ_params)) |> + set_clinical_treatment(drug = 1, timesteps = 1, coverages = 0.6) |> + set_antimalarial_resistance(drug = 1, + timesteps = c(1, 10), + artemisinin_resistance = c(0.4, 0.5), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), + early_treatment_failure_prob = c(0.8, 0.8), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(1, 0), + reinfection_prob = c(0, 0), + slow_parasite_clearance_time = 10), + "Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + + # Reinfection During Prophylaxis + expect_error(get_parameters() |> + set_drugs(drugs = list(SP_AQ_params)) |> + set_clinical_treatment(drug = 1, timesteps = 1, coverages = 0.6) |> + set_antimalarial_resistance(drug = 1, + timesteps = c(1, 10), + artemisinin_resistance = c(0.4, 0.5), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), + early_treatment_failure_prob = c(0.8, 0.8), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0.21, 0), + slow_parasite_clearance_time = 10), + "Parameters set for unimplemented feature - late clinical failure, late parasitological failure, or reinfection during prophylaxis") + +}) + test_that('get_antimalarial_resistance_parameters() correctly retrieves parameters when multiple drugs assigned', { get_parameters(overrides = list(human_population = 10000)) %>% @@ -213,32 +298,32 @@ test_that('get_antimalarial_resistance_parameters() correctly retrieves paramete set_antimalarial_resistance(drug = 2, timesteps = c(0, 20), artemisinin_resistance = c(0.02, 0.2), - partner_drug_resistance = c(0.02, 0.2), - slow_parasite_clearance_prob = c(0.02, 0.2), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0.02, 0.2), - late_clinical_failure_prob = c(0.02, 0.2), - late_parasitological_prob = c(0.02, 0.2), - reinfection_prob = c(0.02, 0.2), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = 20) %>% set_antimalarial_resistance(drug = 1, timesteps = c(0, 10), artemisinin_resistance = c(0.01, 0.1), - partner_drug_resistance = c(0.01, 0.1), - slow_parasite_clearance_prob = c(0.01, 0.1), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0.01, 0.1), - late_clinical_failure_prob = c(0.01, 0.1), - late_parasitological_prob = c(0.01, 0.1), - reinfection_prob = c(0.01, 0.1), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = 10) %>% set_antimalarial_resistance(drug = 3, timesteps = c(0, 30), artemisinin_resistance = c(0.03, 0.3), - partner_drug_resistance = c(0.03, 0.3), - slow_parasite_clearance_prob = c(0.03, 0.3), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0.03, 0.3), - late_clinical_failure_prob = c(0.03, 0.3), - late_parasitological_prob = c(0.03, 0.3), - reinfection_prob = c(0.03, 0.3), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = 30) -> parameters drugs <- c(1, 3, 2, 1, 2, 3, 3, 3, 2, 1, 3, 1, 2, 3, 2) @@ -250,12 +335,12 @@ test_that('get_antimalarial_resistance_parameters() correctly retrieves paramete expected_resistance_parameters <- list() expected_resistance_parameters$artemisinin_resistance_proportion <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) - expected_resistance_parameters$partner_drug_resistance_proportion <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) - expected_resistance_parameters$slow_parasite_clearance_probability <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) + expected_resistance_parameters$partner_drug_resistance_proportion <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$slow_parasite_clearance_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) expected_resistance_parameters$early_treatment_failure_probability <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) - expected_resistance_parameters$late_clinical_failure_probability <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) - expected_resistance_parameters$late_parasitological_failure_probability <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) - expected_resistance_parameters$reinfection_during_prophylaxis_probability <- c(0.1, 0.03, 0.2, 0.1, 0.2, 0.03, 0.03, 0.03, 0.2, 0.1, 0.03, 0.1, 0.2, 0.03, 0.2) + expected_resistance_parameters$late_clinical_failure_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$late_parasitological_failure_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$reinfection_during_prophylaxis_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) expected_resistance_parameters$dt_slow_parasite_clearance <- c(10, 30, 20, 10, 20, 30, 30, 30, 20, 10, 30, 10, 20, 30, 20) expect_identical(resistance_parameters, expected = expected_resistance_parameters) @@ -273,12 +358,12 @@ test_that('get_antimalarial_resistance_parameters() correctly retrieves paramete set_antimalarial_resistance(drug = 2, timesteps = c(0, 20), artemisinin_resistance = c(0.02, 0.2), - partner_drug_resistance = c(0.02, 0.2), - slow_parasite_clearance_prob = c(0.02, 0.2), + partner_drug_resistance = c(0, 0), + slow_parasite_clearance_prob = c(0, 0), early_treatment_failure_prob = c(0.02, 0.2), - late_clinical_failure_prob = c(0.02, 0.2), - late_parasitological_prob = c(0.02, 0.2), - reinfection_prob = c(0.02, 0.2), + late_clinical_failure_prob = c(0, 0), + late_parasitological_prob = c(0, 0), + reinfection_prob = c(0, 0), slow_parasite_clearance_time = 20) -> parameters drugs <- c(1, 3, 2, 1, 2, 3, 3, 3, 2, 1, 3, 1, 2, 3, 2) @@ -290,12 +375,12 @@ test_that('get_antimalarial_resistance_parameters() correctly retrieves paramete expected_resistance_parameters <- list() expected_resistance_parameters$artemisinin_resistance_proportion <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) - expected_resistance_parameters$partner_drug_resistance_proportion <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) - expected_resistance_parameters$slow_parasite_clearance_probability <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) + expected_resistance_parameters$partner_drug_resistance_proportion <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$slow_parasite_clearance_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) expected_resistance_parameters$early_treatment_failure_probability <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) - expected_resistance_parameters$late_clinical_failure_probability <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) - expected_resistance_parameters$late_parasitological_failure_probability <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) - expected_resistance_parameters$reinfection_during_prophylaxis_probability <- c(0, 0, 0.2, 0, 0.2, 0, 0, 0, 0.2, 0, 0, 0, 0.2, 0, 0.2) + expected_resistance_parameters$late_clinical_failure_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$late_parasitological_failure_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + expected_resistance_parameters$reinfection_during_prophylaxis_probability <- c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) expected_resistance_parameters$dt_slow_parasite_clearance <- c(5, 5, 20, 5, 20, 5, 5, 5, 20, 5, 5, 5, 20, 5, 20) expect_identical(resistance_parameters, expected = expected_resistance_parameters) From c42e2c52cee9f70e6067f8e6474e62f8428bcae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20Li=C3=A9tar?= Date: Tue, 27 Feb 2024 20:02:49 +0000 Subject: [PATCH 2/4] Add initial support for saving and restoring simulation state. (#280) The model can now be run for a given number of steps, have its state saved, and then restore it and run for some more time steps. Parameters of the model may be modified when resuming, allowing a simulation to be forked with a single historical run and many futures, modelling different intervention scenarios. There are some limitations as to which parameters may be modified. In general, the structure of the simulation must remain the same, with the same variables and same events. This means interventions cannot be enabled or disabled when resuming, only their parameterization can change. Additionally, the timing of intervention events should not be modified. These limitations may be relaxed in the future, if there is a need for it. To avoid changing the existing API, this feature is available as a new `run_resumable_simulation` function. The function returns a pair of values, the existing dataframe with the simulation data and an object representing the internal simulation state. The function can be called a second time, passing the saved simulation state as an extra argument. See the `test-resume.R` file for usage of the new function. The implementation builds on top of individual's new support for this. Individual already saves the state of its native objects, ie. variables and events. The malaria model keeps quite a bit of state outside of individual, which we need to save and restore explicitly. This is done by creating a set of "stateful objects", ie. R6 objects with a pair `save_state` and `restore_state` methods. This interface is implemented by every bit of the model that has state to capture: - `LaggedValue` objects store the short term EIR and FOIM history. - `Solver` objects represent the current state of ODE solvers. - Adult mosquito ODEs keep a history of incubating values which need to be restored. - `CorrelationParameters` stores a randomly sampled value. This needs to be saved to ensure the simulation is resumed with that same value. - In addition to R's native RNG (whose state is already saved by individual), the model uses the dqrng library, whose state needs saving. --- DESCRIPTION | 6 +- R/RcppExports.R | 28 +++++++ R/biting_process.R | 8 +- R/compartmental.R | 81 +++++++++++++++++--- R/correlation.R | 19 +++++ R/lag.R | 8 ++ R/model.R | 64 ++++++++++++++-- R/mosquito_biology.R | 2 +- R/render.R | 2 +- R/stateful.R | 47 ++++++++++++ man/CorrelationParameters.Rd | 32 ++++++++ man/run_resumable_simulation.Rd | 34 +++++++++ src/Random.cpp | 15 +++- src/Random.h | 3 + src/RcppExports.cpp | 84 ++++++++++++++++++++- src/adult_mosquito_eqs.cpp | 24 +++++- src/adult_mosquito_eqs.h | 2 +- src/solver.cpp | 5 ++ src/timeseries.cpp | 51 +++++++++++-- src/timeseries.h | 5 +- src/utils.cpp | 10 +++ tests/testthat/helper-integration.R | 7 ++ tests/testthat/test-biting-integration.R | 2 +- tests/testthat/test-compartmental.R | 39 +++++++--- tests/testthat/test-emergence-integration.R | 33 ++++---- tests/testthat/test-resume.R | 41 ++++++++++ tests/testthat/test-seasonality.R | 6 +- 27 files changed, 583 insertions(+), 75 deletions(-) create mode 100644 R/stateful.R create mode 100644 man/run_resumable_simulation.Rd create mode 100644 tests/testthat/test-resume.R diff --git a/DESCRIPTION b/DESCRIPTION index 1d971723..ce25ce05 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -69,13 +69,15 @@ LazyData: true Remotes: mrc-ide/malariaEquilibrium, mrc-ide/individual +Additional_repositories: + https://mrc-ide.r-universe.dev Imports: - individual (>= 0.1.7), + individual (>= 0.1.13), malariaEquilibrium (>= 1.0.1), Rcpp, statmod, MASS, - dqrng, + dqrng (>= 0.3.2.2), sitmo, BH, R6, diff --git a/R/RcppExports.R b/R/RcppExports.R index 73a51dd0..01f3dc11 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -9,6 +9,14 @@ adult_mosquito_model_update <- function(model, mu, foim, susceptible, f) { invisible(.Call(`_malariasimulation_adult_mosquito_model_update`, model, mu, foim, susceptible, f)) } +adult_mosquito_model_save_state <- function(model) { + .Call(`_malariasimulation_adult_mosquito_model_save_state`, model) +} + +adult_mosquito_model_restore_state <- function(model, state) { + invisible(.Call(`_malariasimulation_adult_mosquito_model_restore_state`, model, state)) +} + create_adult_solver <- function(model, init, r_tol, a_tol, max_steps) { .Call(`_malariasimulation_create_adult_solver`, model, init, r_tol, a_tol, max_steps) } @@ -41,6 +49,10 @@ solver_get_states <- function(solver) { .Call(`_malariasimulation_solver_get_states`, solver) } +solver_set_states <- function(solver, state) { + invisible(.Call(`_malariasimulation_solver_set_states`, solver, state)) +} + solver_step <- function(solver) { invisible(.Call(`_malariasimulation_solver_step`, solver)) } @@ -57,10 +69,26 @@ timeseries_push <- function(timeseries, value, timestep) { invisible(.Call(`_malariasimulation_timeseries_push`, timeseries, value, timestep)) } +timeseries_save_state <- function(timeseries) { + .Call(`_malariasimulation_timeseries_save_state`, timeseries) +} + +timeseries_restore_state <- function(timeseries, state) { + invisible(.Call(`_malariasimulation_timeseries_restore_state`, timeseries, state)) +} + random_seed <- function(seed) { invisible(.Call(`_malariasimulation_random_seed`, seed)) } +random_save_state <- function() { + .Call(`_malariasimulation_random_save_state`) +} + +random_restore_state <- function(state) { + invisible(.Call(`_malariasimulation_random_restore_state`, state)) +} + bernoulli_multi_p_cpp <- function(p) { .Call(`_malariasimulation_bernoulli_multi_p_cpp`, p) } diff --git a/R/biting_process.R b/R/biting_process.R index ec09620d..801aea20 100644 --- a/R/biting_process.R +++ b/R/biting_process.R @@ -103,7 +103,7 @@ simulate_bites <- function( for (s_i in seq_along(parameters$species)) { species_name <- parameters$species[[s_i]] - solver_states <- solver_get_states(solvers[[s_i]]) + solver_states <- solvers[[s_i]]$get_states() p_bitten <- prob_bitten(timestep, variables, s_i, parameters) Q0 <- parameters$Q0[[s_i]] W <- average_p_successful(p_bitten$prob_bitten_survives, .pi, Q0) @@ -167,7 +167,7 @@ simulate_bites <- function( if (parameters$individual_mosquitoes) { # update the ODE with stats for ovoposition calculations aquatic_mosquito_model_update( - models[[s_i]], + models[[s_i]]$.model, species_index$size(), f, mu @@ -189,7 +189,7 @@ simulate_bites <- function( ) } else { adult_mosquito_model_update( - models[[s_i]], + models[[s_i]]$.model, mu, foim, solver_states[[ADULT_ODE_INDICES['Sm']]], @@ -235,7 +235,7 @@ calculate_infectious <- function(species, solvers, variables, parameters) { ) ) } - calculate_infectious_compartmental(solver_get_states(solvers[[species]])) + calculate_infectious_compartmental(solvers[[species]]$get_states()) } calculate_infectious_individual <- function( diff --git a/R/compartmental.R b/R/compartmental.R index f3405728..7df83a6e 100644 --- a/R/compartmental.R +++ b/R/compartmental.R @@ -50,16 +50,16 @@ parameterise_mosquito_models <- function(parameters, timesteps) { m )[ADULT_ODE_INDICES['Sm']] return( - create_adult_mosquito_model( + AdultMosquitoModel$new(create_adult_mosquito_model( growth_model, parameters$mum[[i]], parameters$dem, susceptible * parameters$init_foim, parameters$init_foim - ) + )) ) } - growth_model + AquaticMosquitoModel$new(growth_model) } ) } @@ -72,22 +72,22 @@ parameterise_solvers <- function(models, parameters) { init <- initial_mosquito_counts(parameters, i, parameters$init_foim, m) if (!parameters$individual_mosquitoes) { return( - create_adult_solver( - models[[i]], + Solver$new(create_adult_solver( + models[[i]]$.model, init, parameters$r_tol, parameters$a_tol, parameters$ode_max_steps - ) + )) ) } - create_aquatic_solver( - models[[i]], + Solver$new(create_aquatic_solver( + models[[i]]$.model, init[ODE_INDICES], parameters$r_tol, parameters$a_tol, parameters$ode_max_steps - ) + )) } ) } @@ -103,7 +103,7 @@ create_compartmental_rendering_process <- function(renderer, solvers, parameters counts <- rep(0, length(indices)) for (s_i in seq_along(solvers)) { if (parameters$species_proportions[[s_i]] > 0) { - row <- solver_get_states(solvers[[s_i]]) + row <- solvers[[s_i]]$get_states() } else { row <- rep(0, length(indices)) } @@ -128,8 +128,67 @@ create_solver_stepping_process <- function(solvers, parameters) { function(timestep) { for (i in seq_along(solvers)) { if (parameters$species_proportions[[i]] > 0) { - solver_step(solvers[[i]]) + solvers[[i]]$step() } } } } + +Solver <- R6::R6Class( + 'Solver', + private = list( + .solver = NULL + ), + public = list( + initialize = function(solver) { + private$.solver <- solver + }, + step = function() { + solver_step(private$.solver) + }, + get_states = function() { + solver_get_states(private$.solver) + }, + + # This is the same as `get_states`, just exposed under the interface that + # is expected of stateful objects. + save_state = function() { + solver_get_states(private$.solver) + }, + restore_state = function(state) { + solver_set_states(private$.solver, state) + } + ) +) + +AquaticMosquitoModel <- R6::R6Class( + 'AquaticMosquitoModel', + public = list( + .model = NULL, + initialize = function(model) { + self$.model <- model + }, + + # The aquatic mosquito model doesn't have any state to save or restore (the + # state of the ODE is stored separately). We still provide these methods to + # conform to the expected interface. + save_state = function() { NULL }, + restore_state = function(state) { } + ) +) + +AdultMosquitoModel <- R6::R6Class( + 'AdultMosquitoModel', + public = list( + .model = NULL, + initialize = function(model) { + self$.model <- model + }, + save_state = function() { + adult_mosquito_model_save_state(self$.model) + }, + restore_state = function(state) { + adult_mosquito_model_restore_state(self$.model, state) + } + ) +) diff --git a/R/correlation.R b/R/correlation.R index 68119369..df5f88f5 100644 --- a/R/correlation.R +++ b/R/correlation.R @@ -115,6 +115,25 @@ CorrelationParameters <- R6::R6Class( dimnames(private$.mvnorm)[[2]] <- private$interventions } private$.mvnorm + }, + + #' @description Save the correlation state. + save_state = function() { + # mvnorm is sampled at random lazily on its first use. We need to save it + # in order to restore the same value when resuming the simulation, + # otherwise we would be drawing a new, probably different, value. + # The rest of the object is derived deterministically from the parameters + # and does not need saving. + list(mvnorm=private$.mvnorm) + }, + + #' @description Restore the correlation state. + #' Only the randomly drawn weights are restored. The object needs to be + #' initialized with the same rhos. + #' @param state a previously saved correlation state, as returned by the + #' save_state method. + restore_state = function(state) { + private$.mvnorm <- state$mvnorm } ) ) diff --git a/R/lag.R b/R/lag.R index 35e47026..ca858ac7 100644 --- a/R/lag.R +++ b/R/lag.R @@ -14,6 +14,14 @@ LaggedValue <- R6::R6Class( get = function(timestep) { timeseries_at(private$history, timestep, TRUE) + }, + + save_state = function() { + timeseries_save_state(private$history) + }, + + restore_state = function(state) { + timeseries_restore_state(private$history, state) } ) ) diff --git a/R/model.R b/R/model.R index 4c4f7c2b..34bbfb46 100644 --- a/R/model.R +++ b/R/model.R @@ -87,6 +87,28 @@ run_simulation <- function( timesteps, parameters = NULL, correlations = NULL +) { + run_resumable_simulation(timesteps, parameters, correlations)$data +} + +#' @title Run the simulation in a resumable way +#' +#' @description this function accepts an initial simulation state as an argument, and returns the +#' final state after running all of its timesteps. This allows one run to be resumed, possibly +#' having changed some of the parameters. +#' @param timesteps the timestep at which to stop the simulation +#' @param parameters a named list of parameters to use +#' @param correlations correlation parameters +#' @param initial_state the state from which the simulation is resumed +#' @param restore_random_state if TRUE, restore the random number generator's state from the checkpoint. +#' @return a list with two entries, one for the dataframe of results and one for the final +#' simulation state. +run_resumable_simulation <- function( + timesteps, + parameters = NULL, + correlations = NULL, + initial_state = NULL, + restore_random_state = FALSE ) { random_seed(ceiling(runif(1) * .Machine$integer.max)) if (is.null(parameters)) { @@ -108,7 +130,23 @@ run_simulation <- function( ) vector_models <- parameterise_mosquito_models(parameters, timesteps) solvers <- parameterise_solvers(vector_models, parameters) - individual::simulation_loop( + + lagged_eir <- create_lagged_eir(variables, solvers, parameters) + lagged_infectivity <- create_lagged_infectivity(variables, parameters) + + stateful_objects <- unlist(list( + RandomState$new(restore_random_state), + correlations, + vector_models, + solvers, + lagged_eir, + lagged_infectivity)) + + if (!is.null(initial_state)) { + restore_state(initial_state$malariasimulation, stateful_objects) + } + + individual_state <- individual::simulation_loop( processes = create_processes( renderer, variables, @@ -117,15 +155,31 @@ run_simulation <- function( vector_models, solvers, correlations, - list(create_lagged_eir(variables, solvers, parameters)), - list(create_lagged_infectivity(variables, parameters)), + list(lagged_eir), + list(lagged_infectivity), timesteps ), variables = variables, events = unlist(events), - timesteps = timesteps + timesteps = timesteps, + state = initial_state$individual, + restore_random_state = restore_random_state ) - renderer$to_dataframe() + + final_state <- list( + timesteps=timesteps, + individual=individual_state, + malariasimulation=save_state(stateful_objects) + ) + + data <- renderer$to_dataframe() + if (!is.null(initial_state)) { + # Drop the timesteps we didn't simulate from the data. + # It would just be full of NA. + data <- data[-(1:initial_state$timesteps),] + } + + list(data=data, state=final_state) } #' @title Run a metapopulation model diff --git a/R/mosquito_biology.R b/R/mosquito_biology.R index 381450d9..4225a46f 100644 --- a/R/mosquito_biology.R +++ b/R/mosquito_biology.R @@ -220,7 +220,7 @@ create_mosquito_emergence_process <- function( p_counts <- vnapply( solvers, function(solver) { - solver_get_states(solver)[[ODE_INDICES[['P']]]] + solver$get_states()[[ODE_INDICES[['P']]]] } ) n <- sum(p_counts) * rate diff --git a/R/render.R b/R/render.R index 3fcf4882..60cb5b73 100644 --- a/R/render.R +++ b/R/render.R @@ -146,7 +146,7 @@ create_total_M_renderer_compartmental <- function(renderer, solvers, parameters) function(timestep) { total_M <- 0 for (i in seq_along(solvers)) { - row <- solver_get_states(solvers[[i]]) + row <- solvers[[i]]$get_states() species_M <- sum(row[ADULT_ODE_INDICES]) total_M <- total_M + species_M renderer$render(paste0('total_M_', parameters$species[[i]]), species_M, timestep) diff --git a/R/stateful.R b/R/stateful.R new file mode 100644 index 00000000..05c614a0 --- /dev/null +++ b/R/stateful.R @@ -0,0 +1,47 @@ +#' @title Save the state of a list of \emph{stateful objects} +#' @description The state of each element is saved and stored into a single +#' object, representing them in a way that can be exported and re-used later. +#' @param objects A list of stateful objects to be saved. Stateful objects are +#' instances of R6 classes with a pair of \code{save_state} and +#' \code{restore_state} methods. +#' @noRd +save_state <- function(objects) { + lapply(objects, function(o) o$save_state()) +} + +#' @title Restore the state of a collection of stateful objects +#' @description This is the counterpart of \code{save_state}. Calling it +#' restores the collection of objects back into their original state. +#' @param state A state object, as returned by the \code{save_state} function. +#' @param objects A collection of stateful objects to be restored. +#' @noRd +restore_state <- function(state, objects) { + stopifnot(length(state) == length(objects)) + for (i in seq_along(state)) { + objects[[i]]$restore_state(state[[i]]) + } +} + +#' @title a placeholder class to save the random number generator class. +#' @description the class integrates with the simulation loop to save and +#' restore the random number generator class when appropriate. +#' @noRd +RandomState <- R6::R6Class( + 'RandomState', + private = list( + restore_random_state = NULL + ), + public = list( + initialize = function(restore_random_state) { + private$restore_random_state <- restore_random_state + }, + save_state = function() { + random_save_state() + }, + restore_state = function(state) { + if (private$restore_random_state) { + random_restore_state(state) + } + } + ) +) diff --git a/man/CorrelationParameters.Rd b/man/CorrelationParameters.Rd index c2e6ada7..480d663d 100644 --- a/man/CorrelationParameters.Rd +++ b/man/CorrelationParameters.Rd @@ -19,6 +19,8 @@ Describes an event in the simulation \item \href{#method-CorrelationParameters-inter_intervention_rho}{\code{CorrelationParameters$inter_intervention_rho()}} \item \href{#method-CorrelationParameters-sigma}{\code{CorrelationParameters$sigma()}} \item \href{#method-CorrelationParameters-mvnorm}{\code{CorrelationParameters$mvnorm()}} +\item \href{#method-CorrelationParameters-save_state}{\code{CorrelationParameters$save_state()}} +\item \href{#method-CorrelationParameters-restore_state}{\code{CorrelationParameters$restore_state()}} \item \href{#method-CorrelationParameters-clone}{\code{CorrelationParameters$clone()}} } } @@ -101,6 +103,36 @@ multivariate norm draws for these parameters \if{html}{\out{
}}\preformatted{CorrelationParameters$mvnorm()}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CorrelationParameters-save_state}{}}} +\subsection{Method \code{save_state()}}{ +Save the correlation state. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CorrelationParameters$save_state()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CorrelationParameters-restore_state}{}}} +\subsection{Method \code{restore_state()}}{ +Restore the correlation state. +Only the randomly drawn weights are restored. The object needs to be +initialized with the same rhos. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CorrelationParameters$restore_state(state)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{state}}{a previously saved correlation state, as returned by the +save_state method.} +} +\if{html}{\out{
}} +} } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/run_resumable_simulation.Rd b/man/run_resumable_simulation.Rd new file mode 100644 index 00000000..8991fd1c --- /dev/null +++ b/man/run_resumable_simulation.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model.R +\name{run_resumable_simulation} +\alias{run_resumable_simulation} +\title{Run the simulation in a resumable way} +\usage{ +run_resumable_simulation( + timesteps, + parameters = NULL, + correlations = NULL, + initial_state = NULL, + restore_random_state = FALSE +) +} +\arguments{ +\item{timesteps}{the timestep at which to stop the simulation} + +\item{parameters}{a named list of parameters to use} + +\item{correlations}{correlation parameters} + +\item{initial_state}{the state from which the simulation is resumed} + +\item{restore_random_state}{if TRUE, restore the random number generator's state from the checkpoint.} +} +\value{ +a list with two entries, one for the dataframe of results and one for the final +simulation state. +} +\description{ +this function accepts an initial simulation state as an argument, and returns the +final state after running all of its timesteps. This allows one run to be resumed, possibly +having changed some of the parameters. +} diff --git a/src/Random.cpp b/src/Random.cpp index fad22963..2906ccc7 100644 --- a/src/Random.cpp +++ b/src/Random.cpp @@ -61,7 +61,7 @@ void Random::prop_sample_bucket( // all probabilities are the same if (heavy == n) { for (auto i = 0; i < size; ++i) { - *result = (*rng)(n); + *result = (*rng)((uint64_t)n); ++result; } return; @@ -122,10 +122,21 @@ void Random::prop_sample_bucket( // sample for (auto i = 0; i < size; ++i) { - size_t bucket = (*rng)(n); + size_t bucket = (*rng)((uint64_t)n); double acceptance = dqrng::uniform01((*rng)()); *result = (acceptance < dividing_probs[bucket]) ? bucket : alternative_index[bucket]; ++result; } } + +std::string Random::save_state() { + std::ostringstream stream; + stream << *rng; + return stream.str(); +} + +void Random::restore_state(std::string state) { + std::istringstream stream(state); + stream >> *rng; +} diff --git a/src/Random.h b/src/Random.h index b6e15039..e796fb9a 100644 --- a/src/Random.h +++ b/src/Random.h @@ -58,6 +58,9 @@ class Random : public RandomInterface { Random(Random &&other) = delete; Random& operator=(const Random &other) = delete; Random& operator=(Random &&other) = delete; + + std::string save_state(); + void restore_state(std::string state); private: Random() : rng(dqrng::generator(42)) {}; }; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f5c226fd..7675e8e6 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -40,6 +40,28 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// adult_mosquito_model_save_state +std::vector adult_mosquito_model_save_state(Rcpp::XPtr model); +RcppExport SEXP _malariasimulation_adult_mosquito_model_save_state(SEXP modelSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type model(modelSEXP); + rcpp_result_gen = Rcpp::wrap(adult_mosquito_model_save_state(model)); + return rcpp_result_gen; +END_RCPP +} +// adult_mosquito_model_restore_state +void adult_mosquito_model_restore_state(Rcpp::XPtr model, std::vector state); +RcppExport SEXP _malariasimulation_adult_mosquito_model_restore_state(SEXP modelSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type model(modelSEXP); + Rcpp::traits::input_parameter< std::vector >::type state(stateSEXP); + adult_mosquito_model_restore_state(model, state); + return R_NilValue; +END_RCPP +} // create_adult_solver Rcpp::XPtr create_adult_solver(Rcpp::XPtr model, std::vector init, double r_tol, double a_tol, size_t max_steps); RcppExport SEXP _malariasimulation_create_adult_solver(SEXP modelSEXP, SEXP initSEXP, SEXP r_tolSEXP, SEXP a_tolSEXP, SEXP max_stepsSEXP) { @@ -168,6 +190,17 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// solver_set_states +void solver_set_states(Rcpp::XPtr solver, std::vector state); +RcppExport SEXP _malariasimulation_solver_set_states(SEXP solverSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type solver(solverSEXP); + Rcpp::traits::input_parameter< std::vector >::type state(stateSEXP); + solver_set_states(solver, state); + return R_NilValue; +END_RCPP +} // solver_step void solver_step(Rcpp::XPtr solver); RcppExport SEXP _malariasimulation_solver_step(SEXP solverSEXP) { @@ -215,6 +248,28 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// timeseries_save_state +Rcpp::List timeseries_save_state(Rcpp::XPtr timeseries); +RcppExport SEXP _malariasimulation_timeseries_save_state(SEXP timeseriesSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type timeseries(timeseriesSEXP); + rcpp_result_gen = Rcpp::wrap(timeseries_save_state(timeseries)); + return rcpp_result_gen; +END_RCPP +} +// timeseries_restore_state +void timeseries_restore_state(Rcpp::XPtr timeseries, Rcpp::List state); +RcppExport SEXP _malariasimulation_timeseries_restore_state(SEXP timeseriesSEXP, SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type timeseries(timeseriesSEXP); + Rcpp::traits::input_parameter< Rcpp::List >::type state(stateSEXP); + timeseries_restore_state(timeseries, state); + return R_NilValue; +END_RCPP +} // random_seed void random_seed(size_t seed); RcppExport SEXP _malariasimulation_random_seed(SEXP seedSEXP) { @@ -225,6 +280,26 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// random_save_state +std::string random_save_state(); +RcppExport SEXP _malariasimulation_random_save_state() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(random_save_state()); + return rcpp_result_gen; +END_RCPP +} +// random_restore_state +void random_restore_state(std::string state); +RcppExport SEXP _malariasimulation_random_restore_state(SEXP stateSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< std::string >::type state(stateSEXP); + random_restore_state(state); + return R_NilValue; +END_RCPP +} // bernoulli_multi_p_cpp std::vector bernoulli_multi_p_cpp(const std::vector p); RcppExport SEXP _malariasimulation_bernoulli_multi_p_cpp(SEXP pSEXP) { @@ -260,11 +335,13 @@ BEGIN_RCPP END_RCPP } -RcppExport SEXP run_testthat_tests(); +RcppExport SEXP run_testthat_tests(void); static const R_CallMethodDef CallEntries[] = { {"_malariasimulation_create_adult_mosquito_model", (DL_FUNC) &_malariasimulation_create_adult_mosquito_model, 5}, {"_malariasimulation_adult_mosquito_model_update", (DL_FUNC) &_malariasimulation_adult_mosquito_model_update, 5}, + {"_malariasimulation_adult_mosquito_model_save_state", (DL_FUNC) &_malariasimulation_adult_mosquito_model_save_state, 1}, + {"_malariasimulation_adult_mosquito_model_restore_state", (DL_FUNC) &_malariasimulation_adult_mosquito_model_restore_state, 2}, {"_malariasimulation_create_adult_solver", (DL_FUNC) &_malariasimulation_create_adult_solver, 5}, {"_malariasimulation_create_aquatic_mosquito_model", (DL_FUNC) &_malariasimulation_create_aquatic_mosquito_model, 18}, {"_malariasimulation_aquatic_mosquito_model_update", (DL_FUNC) &_malariasimulation_aquatic_mosquito_model_update, 4}, @@ -273,11 +350,16 @@ static const R_CallMethodDef CallEntries[] = { {"_malariasimulation_eggs_laid", (DL_FUNC) &_malariasimulation_eggs_laid, 3}, {"_malariasimulation_rainfall", (DL_FUNC) &_malariasimulation_rainfall, 5}, {"_malariasimulation_solver_get_states", (DL_FUNC) &_malariasimulation_solver_get_states, 1}, + {"_malariasimulation_solver_set_states", (DL_FUNC) &_malariasimulation_solver_set_states, 2}, {"_malariasimulation_solver_step", (DL_FUNC) &_malariasimulation_solver_step, 1}, {"_malariasimulation_create_timeseries", (DL_FUNC) &_malariasimulation_create_timeseries, 2}, {"_malariasimulation_timeseries_at", (DL_FUNC) &_malariasimulation_timeseries_at, 3}, {"_malariasimulation_timeseries_push", (DL_FUNC) &_malariasimulation_timeseries_push, 3}, + {"_malariasimulation_timeseries_save_state", (DL_FUNC) &_malariasimulation_timeseries_save_state, 1}, + {"_malariasimulation_timeseries_restore_state", (DL_FUNC) &_malariasimulation_timeseries_restore_state, 2}, {"_malariasimulation_random_seed", (DL_FUNC) &_malariasimulation_random_seed, 1}, + {"_malariasimulation_random_save_state", (DL_FUNC) &_malariasimulation_random_save_state, 0}, + {"_malariasimulation_random_restore_state", (DL_FUNC) &_malariasimulation_random_restore_state, 1}, {"_malariasimulation_bernoulli_multi_p_cpp", (DL_FUNC) &_malariasimulation_bernoulli_multi_p_cpp, 1}, {"_malariasimulation_bitset_index_cpp", (DL_FUNC) &_malariasimulation_bitset_index_cpp, 2}, {"_malariasimulation_fast_weighted_sample", (DL_FUNC) &_malariasimulation_fast_weighted_sample, 2}, diff --git a/src/adult_mosquito_eqs.cpp b/src/adult_mosquito_eqs.cpp index ad5553eb..8126c7d7 100644 --- a/src/adult_mosquito_eqs.cpp +++ b/src/adult_mosquito_eqs.cpp @@ -17,7 +17,7 @@ AdultMosquitoModel::AdultMosquitoModel( ) : growth_model(growth_model), mu(mu), tau(tau), foim(foim) { for (auto i = 0u; i < tau; ++i) { - lagged_incubating.push(incubating); + lagged_incubating.push_back(incubating); } } @@ -82,12 +82,30 @@ void adult_mosquito_model_update( model->foim = foim; model->growth_model.f = f; model->growth_model.mum = mu; - model->lagged_incubating.push(susceptible * foim); + model->lagged_incubating.push_back(susceptible * foim); if (model->lagged_incubating.size() > 0) { - model->lagged_incubating.pop(); + model->lagged_incubating.pop_front(); } } +//[[Rcpp::export]] +std::vector adult_mosquito_model_save_state( + Rcpp::XPtr model + ) { + // Only the lagged_incubating needs to be saved. The rest of the model + // state is reset at each time step by a call to update before it gets + // used. + return {model->lagged_incubating.begin(), model->lagged_incubating.end()}; +} + +//[[Rcpp::export]] +void adult_mosquito_model_restore_state( + Rcpp::XPtr model, + std::vector state + ) { + model->lagged_incubating.assign(state.begin(), state.end()); +} + //[[Rcpp::export]] Rcpp::XPtr create_adult_solver( Rcpp::XPtr model, diff --git a/src/adult_mosquito_eqs.h b/src/adult_mosquito_eqs.h index 6fc30501..c3e0ac8b 100644 --- a/src/adult_mosquito_eqs.h +++ b/src/adult_mosquito_eqs.h @@ -28,7 +28,7 @@ enum class AdultState : size_t {S = 3, E = 4, I = 5}; */ struct AdultMosquitoModel { AquaticMosquitoModel growth_model; - std::queue lagged_incubating; //last tau values for incubating mosquitos + std::deque lagged_incubating; //last tau values for incubating mosquitos double mu; //death rate for adult female mosquitoes const double tau; //extrinsic incubation period double foim; //force of infection towards mosquitoes diff --git a/src/solver.cpp b/src/solver.cpp index 7cb8b9f4..0a5f7401 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -47,6 +47,11 @@ std::vector solver_get_states(Rcpp::XPtr solver) { return solver->state; } +//[[Rcpp::export]] +void solver_set_states(Rcpp::XPtr solver, std::vector state) { + solver->state = state; +} + //[[Rcpp::export]] void solver_step(Rcpp::XPtr solver) { solver->step(); diff --git a/src/timeseries.cpp b/src/timeseries.cpp index 2383ac73..51a494de 100644 --- a/src/timeseries.cpp +++ b/src/timeseries.cpp @@ -15,18 +15,18 @@ Timeseries::Timeseries(size_t max_size, double default_value) : max_size(max_size), has_default(true), default_value(default_value) {} void Timeseries::push(double value, double time) { - values.insert({time, value}); + _values.insert({time, value}); if (max_size != -1) { - while(values.size() > max_size) { - values.erase(values.begin()); + while(_values.size() > max_size) { + _values.erase(_values.begin()); } } } double Timeseries::at(double time, bool linear) const { - auto it = values.lower_bound(time); - if (it == values.end()) { - if (values.size() > 0 && !linear) { + auto it = _values.lower_bound(time); + if (it == _values.end()) { + if (_values.size() > 0 && !linear) { it--; return it->second; } @@ -45,7 +45,7 @@ double Timeseries::at(double time, bool linear) const { auto after_element = *it; while(it->first > time) { // Check if we're at the start of the timeseries - if (it == values.begin()) { + if (it == _values.begin()) { if (has_default) { return default_value; } @@ -64,6 +64,14 @@ double Timeseries::at(double time, bool linear) const { return it->second; } +const std::map& Timeseries::values() { + return _values; +} + +void Timeseries::set_values(std::map values) { + _values = std::move(values); +} + //[[Rcpp::export]] Rcpp::XPtr create_timeseries(size_t size, double default_value) { return Rcpp::XPtr(new Timeseries(size, default_value), true); @@ -78,3 +86,32 @@ double timeseries_at(Rcpp::XPtr timeseries, double timestep, bool li void timeseries_push(Rcpp::XPtr timeseries, double value, double timestep) { return timeseries->push(value, timestep); } + +//[[Rcpp::export]] +Rcpp::List timeseries_save_state(Rcpp::XPtr timeseries) { + std::vector timesteps; + std::vector values; + for (const auto& entry: timeseries->values()) { + timesteps.push_back(entry.first); + values.push_back(entry.second); + } + return Rcpp::DataFrame::create( + Rcpp::Named("timestep") = timesteps, + Rcpp::Named("value") = values + ); +} + +//[[Rcpp::export]] +void timeseries_restore_state(Rcpp::XPtr timeseries, Rcpp::List state) { + std::vector timesteps = state["timestep"]; + std::vector values = state["value"]; + if (timesteps.size() != values.size()) { + Rcpp::stop("Bad size"); + } + + std::map values_map; + for (size_t i = 0; i < timesteps.size(); i++) { + values_map.insert({timesteps[i], values[i]}); + } + timeseries->set_values(std::move(values_map)); +} diff --git a/src/timeseries.h b/src/timeseries.h index c9e7c32f..78f3780b 100644 --- a/src/timeseries.h +++ b/src/timeseries.h @@ -12,7 +12,7 @@ class Timeseries { private: - std::map values; + std::map _values; size_t max_size; void clean(); bool has_default; @@ -23,6 +23,9 @@ class Timeseries { Timeseries(size_t, double); void push(double, double); double at(double, bool = true) const; + + const std::map& values(); + void set_values(std::map state); }; #endif /* SRC_TIMESERIES_ */ diff --git a/src/utils.cpp b/src/utils.cpp index 838a9530..d2b36043 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -8,6 +8,16 @@ void random_seed(size_t seed) { Random::get_instance().seed(seed); } +//[[Rcpp::export]] +std::string random_save_state() { + return Random::get_instance().save_state(); +} + +//[[Rcpp::export]] +void random_restore_state(std::string state) { + return Random::get_instance().restore_state(state); +} + //[[Rcpp::export]] std::vector bernoulli_multi_p_cpp(const std::vector p) { auto values = Random::get_instance().bernoulli_multi_p(p); diff --git a/tests/testthat/helper-integration.R b/tests/testthat/helper-integration.R index b69b7dcb..22c7dd63 100644 --- a/tests/testthat/helper-integration.R +++ b/tests/testthat/helper-integration.R @@ -88,6 +88,13 @@ mock_event <- function(event) { ) } +mock_solver <- function(states) { + list( + get_states = mockery::mock(states), + step = mockery::mock() + ) +} + expect_bitset_update <- function(mock, value, index, call = 1) { expect_equal(mockery::mock_args(mock)[[call]][[1]], value) expect_equal(mockery::mock_args(mock)[[call]][[2]]$to_vector(), index) diff --git a/tests/testthat/test-biting-integration.R b/tests/testthat/test-biting-integration.R index 66bf5b72..7da79b66 100644 --- a/tests/testthat/test-biting-integration.R +++ b/tests/testthat/test-biting-integration.R @@ -131,7 +131,7 @@ test_that('simulate_bites integrates eir calculation and mosquito side effects', expect_equal(effects_args[[1]][[8]], parameters) expect_equal(effects_args[[1]][[9]], timestep) - mockery::expect_args(eqs_update, 1, models[[1]], 25, f, parameters$mum) + mockery::expect_args(eqs_update, 1, models[[1]]$.model, 25, f, parameters$mum) mockery::expect_args( pois_mock, 1, diff --git a/tests/testthat/test-compartmental.R b/tests/testthat/test-compartmental.R index a4fb3b6a..9958c7a5 100644 --- a/tests/testthat/test-compartmental.R +++ b/tests/testthat/test-compartmental.R @@ -11,9 +11,14 @@ test_that('ODE stays at equilibrium with a constant total_M', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } expected <- c() @@ -41,16 +46,16 @@ test_that('Adult ODE stays at equilibrium with a constant foim and mu', { counts <- c() for (t in seq(timesteps)) { - states <- solver_get_states(solvers[[1]]) + states <- solvers[[1]]$get_states() counts <- rbind(counts, c(t, states)) adult_mosquito_model_update( - models[[1]], + models[[1]]$.model, parameters$mum, parameters$init_foim, states[ADULT_ODE_INDICES['Sm']], f ) - solver_step(solvers[[1]]) + solvers[[1]]$step() } expected <- c() @@ -82,9 +87,14 @@ test_that('ODE stays at equilibrium with low total_M', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } expected <- c() @@ -121,14 +131,19 @@ test_that('Changing total_M stabilises', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) if (t < change) { total_M <- total_M_0 } else { total_M <- total_M_1 } - aquatic_mosquito_model_update(models[[1]], total_M, f, parameters$mum) - solver_step(solvers[[1]]) + aquatic_mosquito_model_update( + models[[1]]$.model, + total_M, + f, + parameters$mum + ) + solvers[[1]]$step() } initial_eq <- initial_mosquito_counts( diff --git a/tests/testthat/test-emergence-integration.R b/tests/testthat/test-emergence-integration.R index fa7464be..bf452d2f 100644 --- a/tests/testthat/test-emergence-integration.R +++ b/tests/testthat/test-emergence-integration.R @@ -8,22 +8,19 @@ test_that('emergence process fails when there are not enough individuals', { c('gamb'), rep('gamb', 2000) ) + solvers <- list( + mock_solver(c(1000, 500, 100)), + mock_solver(c(1000, 500, 100)) + ) + emergence_process <- create_mosquito_emergence_process( - list(), + solvers, state, species, c('gamb'), parameters$dpl ) - mockery::stub( - emergence_process, - 'solver_get_states', - mockery::mock( - c(1000, 500, 100), - c(1000, 500, 100) - ) - ) - expect_error(emergence_process(0), '*') + expect_error(emergence_process(0), 'Not enough mosquitoes') }) test_that('emergence_process creates the correct number of susceptables', { @@ -36,23 +33,19 @@ test_that('emergence_process creates the correct number of susceptables', { c('a', 'b'), c('a', 'b') ) + solvers <- list( + mock_solver(c(100000, 50000, 10000)), + mock_solver(c(1000, 5000, 1000)) + ) + emergence_process <- create_mosquito_emergence_process( - list(mockery::mock(), mockery::mock()), + solvers, state, species, c('a', 'b'), parameters$dpl ) - mockery::stub( - emergence_process, - 'solver_get_states', - mockery::mock( - c(100000, 50000, 10000), - c(10000, 5000, 1000) - ) - ) - emergence_process(0) expect_bitset_update( diff --git a/tests/testthat/test-resume.R b/tests/testthat/test-resume.R new file mode 100644 index 00000000..83bada35 --- /dev/null +++ b/tests/testthat/test-resume.R @@ -0,0 +1,41 @@ +test_that('Simulation can be resumed', { + initial_timesteps <- 50 + total_timesteps <- 100 + + parameters <- get_parameters() + + set.seed(1) + first_phase <- run_resumable_simulation(initial_timesteps, parameters) + second_phase <- run_resumable_simulation( + total_timesteps, + parameters, + initial_state=first_phase$state, + restore_random_state=TRUE) + + set.seed(1) + uninterrupted_run <- run_simulation(total_timesteps, parameters=parameters) + + expect_equal(nrow(first_phase$data), initial_timesteps) + expect_equal(nrow(second_phase$data), total_timesteps - initial_timesteps) + expect_equal(rbind(first_phase$data, second_phase$data), uninterrupted_run) +}) + +test_that('Intervention parameters can be changed when resuming', { + initial_timesteps <- 50 + total_timesteps <- 100 + tbv_timesteps <- 70 + + # Because of how event scheduling works, we must enable TBV even in the inital phase. + # We set a coverage of 0 to act as-if it was disabled. + initial_parameters <- get_parameters() |> set_tbv(timesteps=tbv_timesteps, coverage=0, ages=5:60) + + tbv_parameters <- initial_parameters |> + set_tbv(timesteps=tbv_timesteps, coverage=1, ages=5:60) + + initial_run <- run_resumable_simulation(initial_timesteps, initial_parameters) + control_run <- run_resumable_simulation(total_timesteps, initial_parameters, initial_state = initial_run$state) + tbv_run <- run_resumable_simulation(total_timesteps, tbv_parameters, initial_state = initial_run$state) + + expect_equal(control_run$data$n_vaccinated_tbv[tbv_timesteps - initial_timesteps], 0) + expect_gt(tbv_run$data$n_vaccinated_tbv[tbv_timesteps - initial_timesteps], 0) +}) diff --git a/tests/testthat/test-seasonality.R b/tests/testthat/test-seasonality.R index 01302d8d..f600ebf4 100644 --- a/tests/testthat/test-seasonality.R +++ b/tests/testthat/test-seasonality.R @@ -15,14 +15,14 @@ test_that('Seasonality correctly affects P', { counts <- c() for (t in seq(timesteps)) { - counts <- rbind(counts, c(t, solver_get_states(solvers[[1]]))) + counts <- rbind(counts, c(t, solvers[[1]]$get_states())) aquatic_mosquito_model_update( - models[[1]], + models[[1]]$.model, total_M, parameters$blood_meal_rates, parameters$mum ) - solver_step(solvers[[1]]) + solvers[[1]]$step() } burn_in <- 20 From b3376d33cda29cc5e4d7217fefad4b367f9f9167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20Li=C3=A9tar?= Date: Thu, 21 Mar 2024 13:21:06 +0000 Subject: [PATCH 3/4] Rewrite the exponential decay process in C++. (#285) The existing process written in R needs to copy the contents of each variable from C++ using `v$get_values()`, then after scaling the vector it would copy the result back into C++ using `v$queue_update`. The amount of data copied and the time it took was pretty significant. 6 double variables, one for each kind of immunity, need to be updated in full at each time step, each as big as the population size. Moving this into C++ removes the need for any copy at all, besides the multication loop. Values are read out of a reference to the vector held by the DoubleVariable, the result of the multication is moved to the queue, and finally individual moves the vector in the queue into the DoubleVariable. The speedup from this change for a 1M population size is around 10%. An alternative optimization I considered was to compute the exponential decay lazily, recording only the timestep and value at which the immunity was last updated and using the closed form expression of the exponential decay. This would avoid the need to have mass updates of the immunity variables at every time step. Unfortunately in my testing this ends up being slower than even the current implementation, with all the time being spent in calculating the current value. This would also be a much more intrusive change, since every use of the immunity variables needs to be modified to take the last update timestep, the current timestep and the decay rate into consideration. --- DESCRIPTION | 2 +- R/RcppExports.R | 4 ++ R/processes.R | 3 +- src/RcppExports.cpp | 13 ++++ src/processes.cpp | 117 ++++++++++++++++++++++++++++++++ tests/testthat/test-processes.R | 23 +++++++ 6 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 src/processes.cpp create mode 100644 tests/testthat/test-processes.R diff --git a/DESCRIPTION b/DESCRIPTION index ce25ce05..d92fa304 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -72,7 +72,7 @@ Remotes: Additional_repositories: https://mrc-ide.r-universe.dev Imports: - individual (>= 0.1.13), + individual (>= 0.1.15), malariaEquilibrium (>= 1.0.1), Rcpp, statmod, diff --git a/R/RcppExports.R b/R/RcppExports.R index 01f3dc11..678a50bf 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -45,6 +45,10 @@ rainfall <- function(t, g0, g, h, floor) { .Call(`_malariasimulation_rainfall`, t, g0, g, h, floor) } +exponential_process_cpp <- function(variable, rate) { + .Call(`_malariasimulation_exponential_process_cpp`, variable, rate) +} + solver_get_states <- function(solver) { .Call(`_malariasimulation_solver_get_states`, solver) } diff --git a/R/processes.R b/R/processes.R index d3b11fda..28810263 100644 --- a/R/processes.R +++ b/R/processes.R @@ -261,8 +261,9 @@ create_processes <- function( #' @param rate the exponential rate #' @noRd create_exponential_decay_process <- function(variable, rate) { + stopifnot(inherits(variable, "DoubleVariable")) decay_rate <- exp(-1/rate) - function(timestep) variable$queue_update(variable$get_values() * decay_rate) + exponential_process_cpp(variable$.variable, decay_rate) } #' @title Create and initialise lagged_infectivity object diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 7675e8e6..439cb12b 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -179,6 +179,18 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// exponential_process_cpp +Rcpp::XPtr exponential_process_cpp(Rcpp::XPtr variable, const double rate); +RcppExport SEXP _malariasimulation_exponential_process_cpp(SEXP variableSEXP, SEXP rateSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::XPtr >::type variable(variableSEXP); + Rcpp::traits::input_parameter< const double >::type rate(rateSEXP); + rcpp_result_gen = Rcpp::wrap(exponential_process_cpp(variable, rate)); + return rcpp_result_gen; +END_RCPP +} // solver_get_states std::vector solver_get_states(Rcpp::XPtr solver); RcppExport SEXP _malariasimulation_solver_get_states(SEXP solverSEXP) { @@ -349,6 +361,7 @@ static const R_CallMethodDef CallEntries[] = { {"_malariasimulation_carrying_capacity", (DL_FUNC) &_malariasimulation_carrying_capacity, 8}, {"_malariasimulation_eggs_laid", (DL_FUNC) &_malariasimulation_eggs_laid, 3}, {"_malariasimulation_rainfall", (DL_FUNC) &_malariasimulation_rainfall, 5}, + {"_malariasimulation_exponential_process_cpp", (DL_FUNC) &_malariasimulation_exponential_process_cpp, 2}, {"_malariasimulation_solver_get_states", (DL_FUNC) &_malariasimulation_solver_get_states, 1}, {"_malariasimulation_solver_set_states", (DL_FUNC) &_malariasimulation_solver_set_states, 2}, {"_malariasimulation_solver_step", (DL_FUNC) &_malariasimulation_solver_step, 1}, diff --git a/src/processes.cpp b/src/processes.cpp new file mode 100644 index 00000000..b5db4bad --- /dev/null +++ b/src/processes.cpp @@ -0,0 +1,117 @@ +#include +#include + +/** + * An iterator adaptor which yields the same values as the underlying iterator, + * but scaled by a pre-determined factor. + * + * This is used by the exponential_process below to scale an std::vector by a + * constant. + * + * There are two straightforward ways of performing the operation. The first is + * to create an empty vector, use `reserve(N)` to pre-allocate the vector and + * then call `push_back` with each new value. The second way would be to create + * a zero-initialised vector of size N and then use `operator[]` to fill in the + * values. + * + * Unfortunately both approaches have significant overhead. In the former, the + * use of `push_back` requires repeated checks as to whether the vector needs + * growing, despite the prior reserve call. These calls inhibits optimizations + * such as unrolling and auto-vectorization of the loop. The latter approach + * requires an initial memset when zero-initializing the vector, even though the + * vector then gets overwritten entirely. Sadly gcc fails to optimize out either + * of those. Ideally we want a way to create a pre-sized but uninitialised + * std::vector we can write to ourselves, but there is no API in the standard + * library to do this. All existing workarounds end up with an std::vector with + * non-default item type or allocators. + * + * There is however a way out! std::vector has a constructor which accepts a + * pair of iterators and fills the vector with values from the iterators. Using + * `std::distance` on the iterator pair it can even pre-allocate the vector to + * the right size. No zero-initialisation or no capacity checks, just one + * allocation and a straightforward easily optimizable loop. All we need is an + * iterator yielding the right values, hence `scale_iterator`. In C++20 we would + * probably be able to use the new ranges library as our iterators. + * + * How much does this matter? On microbenchmarks, for small and medium sized + * vector (<= 1M doubles), this version is about 30% faster than the + * zero-initialising implementation and 60% faster than the one which uses + * push_back. For larger vector sizes the difference is less pronounced, + * possibly because caches become saturated. At the time of writing, on a + * real-word run of malariasimulation with a population size of 1M the overall + * speedup is about 2-3%. + * + * https://wolchok.org/posts/cxx-trap-1-constant-size-vector/ + * https://codingnest.com/the-little-things-the-missing-performance-in-std-vector/ + * https://lemire.me/blog/2012/06/20/do-not-waste-time-with-stl-vectors/ + */ +template +struct scale_iterator { + using iterator_category = std::forward_iterator_tag; + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename std::iterator_traits::value_type; + using pointer = typename std::iterator_traits::pointer; + + // We skirt the rules a bit by returning a prvalue from `operator*`, even + // though the C++17 (and prior) standard says forward iterators are supposed + // to return a reference type (ie. a glvalue). Because the scaling is + // applied on the fly, there is no glvalue we could return a reference to. + // + // An input iterator would be allowed to return a prvalue, but the + // std::vector constructor wouldn't be able to figure out the length ahead + // of time if we were an input iterator. + // + // C++20 actually introduces parallel definitions of input and forward + // iterators, which relax this requirement, so under that classification our + // implementation in correct. + // + // In practice though, this does not really matter. We only use this + // iterator in one specific context, and the vector constructor doesn't do + // anything elaborate that we would be upsetting. + using reference = value_type; + + scale_iterator(underlying_iterator it, value_type factor) : it(it), factor(factor) {} + reference operator*() { + return factor * (*it); + } + bool operator==(const scale_iterator& other) { + return it == other.it; + } + bool operator!=(const scale_iterator& other) { + return it != other.it; + } + scale_iterator& operator++() { + it++; + return *this; + } + scale_iterator operator++(int) { + return scale_iterator(it++, factor); + } + + private: + underlying_iterator it; + value_type factor; +}; + +template +scale_iterator make_scale_iterator(T&& it, typename std::iterator_traits::value_type scale) { + return scale_iterator(std::forward(it), scale); +} + +//[[Rcpp::export]] +Rcpp::XPtr exponential_process_cpp( + Rcpp::XPtr variable, + const double rate +){ + return Rcpp::XPtr( + new process_t([=](size_t t){ + const std::vector& values = variable->get_values(); + std::vector new_values( + make_scale_iterator(values.cbegin(), rate), + make_scale_iterator(values.cend(), rate)); + + variable->queue_update(std::move(new_values), std::vector()); + }), + true + ); +} diff --git a/tests/testthat/test-processes.R b/tests/testthat/test-processes.R new file mode 100644 index 00000000..29f0d6b0 --- /dev/null +++ b/tests/testthat/test-processes.R @@ -0,0 +1,23 @@ +test_that("exponential_decay_process works as expected", { + # This rate gives a halving at every timestep + rate <- -1 / log(0.5) + + v <- individual::DoubleVariable$new(c(0,0.5,1,2,4,10)) + p <- create_exponential_decay_process(v, rate) + + individual:::execute_any_process(p, 1) + v$.update() + + expect_equal(v$get_values(), c(0, 0.25, 0.5, 1, 2, 5)) + + individual:::execute_any_process(p, 2) + v$.update() + + expect_equal(v$get_values(), c(0, 0.125, 0.25, 0.5, 1, 2.5)) +}) + +test_that("exponential_decay_process fails on IntegerVariable", { + rate <- -1 / log(0.5) + v <- individual::IntegerVariable$new(c(0,1,2,3)) + expect_error(create_exponential_decay_process(v, rate)) +}) From 2a3d4cc3a77a0ff0a953d8a3727ef90501b860b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20Li=C3=A9tar?= Date: Wed, 3 Apr 2024 10:56:05 +0100 Subject: [PATCH 4/4] Improve correlation tests. (#287) The tests for the correlation parameters were using a tolerance of 1e2 everywhere. Since the tolerance is relative, assertions would effectively succeed for any value within a few orders of magnitude. Setting obviously wrong expected results in the assertions did not produce any errors. This switches the tolerance to 0.1. This was chosen through experimenting, as something that didn't cause false negative even after many runs, while also being reasonably close. When using 0.01 I did get 1 failure in 100 runs of the test suite. Some of the assertions were using incorrected expected values, which flew under the radar because of the huge tolerance. I've also cleared up the tests and bit and made them more consitent with one another. Finally I changed the `CorrelationParameters` constructor to accept only the values it needs (population size and intervention booleans), rather than the full simulation parameters. This makes the test a bit more concise, and will also help with upcoming tests that work at restoring correlation state while adding interventions. The existing public wrapper `get_correlation_parameters` still has the same interface as before. --- R/correlation.R | 18 ++-- tests/testthat/test-correlation.R | 159 ++++++++++++++++++------------ 2 files changed, 104 insertions(+), 73 deletions(-) diff --git a/R/correlation.R b/R/correlation.R index df5f88f5..458a3015 100644 --- a/R/correlation.R +++ b/R/correlation.R @@ -24,11 +24,11 @@ CorrelationParameters <- R6::R6Class( public = list( #' @description initialise correlation parameters - #' @param parameters model parameters - initialize = function(parameters) { - # Find a list of enabled interventions - enabled <- vlapply(INTS, function(name) parameters[[name]]) - private$interventions <- INTS[enabled] + #' @param population popularion size + #' @param interventions character vector with the name of enabled interventions + initialize = function(population, interventions) { + private$population <- population + private$interventions <- interventions # Initialise a rho matrix for our interventions n_ints <- private$n_ints() @@ -38,9 +38,6 @@ CorrelationParameters <- R6::R6Class( ncol = n_ints, dimnames = list(private$interventions, private$interventions) ) - - # Store population for mvnorm draws - private$population <- parameters$human_population }, #' @description Add rho between rounds @@ -183,7 +180,10 @@ CorrelationParameters <- R6::R6Class( #' #' # You can now pass the correlation parameters to the run_simulation function get_correlation_parameters <- function(parameters) { - CorrelationParameters$new(parameters) + # Find a list of enabled interventions + enabled <- vlapply(INTS, function(name) parameters[[name]]) + + CorrelationParameters$new(parameters$human_population, INTS[enabled]) } #' @title Sample a population to intervene in given the correlation parameters diff --git a/tests/testthat/test-correlation.R b/tests/testthat/test-correlation.R index 880221ba..882831f8 100644 --- a/tests/testthat/test-correlation.R +++ b/tests/testthat/test-correlation.R @@ -1,105 +1,136 @@ test_that('1 correlation between rounds gives sensible samples', { pop <- 1e6 target <- seq(pop) - vaccine_coverage <- .2 - parameters <- get_parameters(list( - human_population = pop, - pev = TRUE - )) - correlations <- get_correlation_parameters(parameters) + + coverage_1 <- .2 + coverage_2 <- .4 + + correlations <- CorrelationParameters$new(pop, c('pev')) correlations$inter_round_rho('pev', 1) - round_1 <- sample_intervention(target, 'pev', vaccine_coverage, correlations) - round_2 <- sample_intervention(target, 'pev', vaccine_coverage, correlations) - expect_equal(sum(round_1), pop * .2, tolerance=1e2) - expect_equal(sum(round_2), pop * .2, tolerance=1e2) - expect_equal(sum(round_1 & round_2), pop * .2, tolerance=1e2) + + round_1 <- sample_intervention(target, 'pev', coverage_1, correlations) + round_2 <- sample_intervention(target, 'pev', coverage_2, correlations) + + expect_equal(sum(round_1), pop * coverage_1, tolerance=.1) + expect_equal(sum(round_2), pop * coverage_2, tolerance=.1) + + expect_equal( + sum(round_1 & round_2), + pop * min(coverage_1, coverage_2), + tolerance=.1) + + expect_equal( + sum(round_1 | round_2), + pop * max(coverage_1, coverage_2), + tolerance=.1) }) test_that('0 correlation between rounds gives sensible samples', { pop <- 1e6 target <- seq(pop) - vaccine_coverage <- .5 - parameters <- get_parameters(list( - human_population = pop, - pev = TRUE - )) - correlations <- get_correlation_parameters(parameters) + + coverage_1 <- .2 + coverage_2 <- .4 + + correlations <- CorrelationParameters$new(pop, c('pev')) correlations$inter_round_rho('pev', 0) - round_1 <- sample_intervention(target, 'pev', vaccine_coverage, correlations) - round_2 <- sample_intervention(target, 'pev', vaccine_coverage, correlations) + + round_1 <- sample_intervention(target, 'pev', coverage_1, correlations) + round_2 <- sample_intervention(target, 'pev', coverage_2, correlations) + + expect_equal(sum(round_1), pop * coverage_1, tolerance=.1) + expect_equal(sum(round_2), pop * coverage_2, tolerance=.1) + expect_equal( - length(intersect(which(round_1), which(round_2))), - pop * .5, - tolerance=1e2 - ) - expect_equal(sum(round_1), sum(round_2), tolerance=1e2) - expect_equal(sum(round_1), pop * .5, tolerance=1e2) + sum(round_1 & round_2), + pop * coverage_1 * coverage_2, + tolerance=.1) + + expect_equal( + sum(round_1 | round_2), + pop * (coverage_1 + coverage_2 - (coverage_1 * coverage_2)), + tolerance=.1) }) test_that('1 correlation between interventions gives sensible samples', { pop <- 1e6 target <- seq(pop) - vaccine_coverage <- .2 - mda_coverage <- .2 - parameters <- get_parameters(list( - human_population = pop, - pev = TRUE, - mda = TRUE - )) - correlations <- get_correlation_parameters(parameters) + + pev_coverage <- .2 + mda_coverage <- .4 + + correlations <- CorrelationParameters$new(pop, c('pev', 'mda')) correlations$inter_round_rho('pev', 1) correlations$inter_round_rho('mda', 1) correlations$inter_intervention_rho('pev', 'mda', 1) - vaccine_sample <- sample_intervention(target, 'pev', vaccine_coverage, correlations) + + pev_sample <- sample_intervention(target, 'pev', pev_coverage, correlations) mda_sample <- sample_intervention(target, 'mda', mda_coverage, correlations) - expect_equal(sum(vaccine_sample), pop * .2, tolerance=1e2) - expect_equal(sum(mda_sample), pop * .2, tolerance=1e2) - expect_equal(sum(vaccine_sample & mda_sample), pop * .2, tolerance=1e2) + expect_equal(sum(pev_sample), pop * pev_coverage, tolerance=.1) + expect_equal(sum(mda_sample), pop * mda_coverage, tolerance=.1) + + expect_equal( + sum(pev_sample & mda_sample), + pop * min(pev_coverage, mda_coverage), + tolerance=.1) + + expect_equal( + sum(pev_sample | mda_sample), + pop * max(pev_coverage, mda_coverage), + tolerance=.1) }) test_that('0 correlation between interventions gives sensible samples', { pop <- 1e6 target <- seq(pop) - vaccine_coverage <- .2 - mda_coverage <- .2 - parameters <- get_parameters(list( - human_population = pop, - pev = TRUE, - mda = TRUE - )) - correlations <- get_correlation_parameters(parameters) + + pev_coverage <- .2 + mda_coverage <- .4 + + correlations <- CorrelationParameters$new(pop, c('pev', 'mda')) correlations$inter_round_rho('pev', 1) correlations$inter_round_rho('mda', 1) correlations$inter_intervention_rho('pev', 'mda', 0) - vaccine_sample <- sample_intervention(target, 'pev', vaccine_coverage, correlations) + + pev_sample <- sample_intervention(target, 'pev', pev_coverage, correlations) mda_sample <- sample_intervention(target, 'mda', mda_coverage, correlations) + + expect_equal(sum(pev_sample), pop * pev_coverage, tolerance=.1) + expect_equal(sum(mda_sample), pop * mda_coverage, tolerance=.1) + expect_equal( - length(intersect(which(vaccine_sample), which(mda_sample))), - pop * .5, - tolerance=1e2 - ) - expect_equal(sum(vaccine_sample), sum(mda_sample), tolerance=1e2) - expect_equal(sum(vaccine_sample), pop * .5, tolerance=1e2) + sum(pev_sample & mda_sample), + pop * pev_coverage * mda_coverage, + tolerance=.1) + + expect_equal( + sum(pev_sample | mda_sample), + pop * (pev_coverage + mda_coverage - (pev_coverage * mda_coverage)), + tolerance=.1) }) test_that('-1 correlation between interventions gives sensible samples', { pop <- 1e6 target <- seq(pop) - vaccine_coverage <- .2 - mda_coverage <- .2 - parameters <- get_parameters(list( - human_population = pop, - pev = TRUE, - mda = TRUE - )) - correlations <- get_correlation_parameters(parameters) + + pev_coverage <- .2 + mda_coverage <- .4 + + correlations <- CorrelationParameters$new(pop, c('pev', 'mda')) correlations$inter_round_rho('pev', 1) correlations$inter_round_rho('mda', 1) correlations$inter_intervention_rho('pev', 'mda', -1) - vaccine_sample <- sample_intervention(target, 'pev', vaccine_coverage, correlations) + + pev_sample <- sample_intervention(target, 'pev', pev_coverage, correlations) mda_sample <- sample_intervention(target, 'mda', mda_coverage, correlations) - expect_equal(length(intersect(which(vaccine_sample), which(mda_sample))), 0) - expect_equal(sum(vaccine_sample), .2 * pop, tolerance=1e2) - expect_equal(sum(mda_sample), .2 * pop, tolerance=1e2) + + expect_equal(sum(pev_sample), pop * pev_coverage, tolerance=.1) + expect_equal(sum(mda_sample), pop * mda_coverage, tolerance=.1) + + expect_equal(sum(pev_sample & mda_sample), 0, tolerance=.1) + expect_equal( + sum(pev_sample | mda_sample), + pop * (pev_coverage + mda_coverage), + tolerance=.1) })