diff --git a/R/check_args_default.R b/R/check_args_default.R index 0c34e1a2..25091ef2 100644 --- a/R/check_args_default.R +++ b/R/check_args_default.R @@ -65,70 +65,22 @@ compartments = compartments_default ) - # add null intervention and vaccination if these are missing - # if not missing, check that they conform to expectations - # add null rate_intervention if this is missing - # if not missing, check that it conforms to expectations - if (is.null(mod_args[["intervention"]])) { - # add dummy list elements named "contacts", and one named "transmissibility" - mod_args[["intervention"]] <- list( - contacts = no_contacts_intervention( - mod_args[["population"]] - ), - transmissibility = no_rate_intervention() - ) - } else { - # check intervention list names - checkmate::assert_names( - names(mod_args[["intervention"]]), - subset.of = c( - "transmissibility", "infectiousness_rate", "recovery_rate", "contacts" - ) - ) - # if a contacts intervention is passed, check it - if ("contacts" %in% names(mod_args[["intervention"]])) { - # check the intervention on contacts - assert_intervention( - mod_args[["intervention"]][["contacts"]], "contacts", - mod_args[["population"]] - ) - } else { - # if not contacts intervention is passed, add a dummy one - mod_args[["intervention"]]$contacts <- no_contacts_intervention( - mod_args[["population"]] - ) - } - - # if there is only an intervention on contacts, add a dummy intervention - # on the transmissibility - if (identical(names(mod_args[["intervention"]]), "contacts")) { - mod_args[["intervention"]]$transmissibility <- no_rate_intervention() - } - } + assert_intervention( + mod_args[["intervention"]][["contacts"]], "contacts", + mod_args[["population"]] + ) - if (is.null(mod_args[["vaccination"]])) { - mod_args[["vaccination"]] <- no_vaccination( - mod_args[["population"]] - ) - } else { - # default model only supports a single dose vaccination - assert_vaccination( - mod_args[["vaccination"]], - doses = 1L, mod_args[["population"]] - ) - } + assert_vaccination( + mod_args[["vaccination"]], + doses = 1L, mod_args[["population"]] + ) - # handle time dependence if not present, and check targets if present - if (is.null(mod_args[["time_dependence"]])) { - mod_args[["time_dependence"]] <- no_time_dependence() - } else { - checkmate::assert_names( - names(mod_args[["time_dependence"]]), - subset.of = c( - "transmissibility", "infectiousness_rate", "recovery_rate" - ) + checkmate::assert_names( + names(mod_args[["time_dependence"]]), + subset.of = c( + "transmissibility", "infectiousness_rate", "recovery_rate" ) - } + ) # return arguments invisibly invisible(mod_args) diff --git a/R/model_default.R b/R/model_default.R index 5b2c4089..e06e1c6b 100644 --- a/R/model_default.R +++ b/R/model_default.R @@ -102,9 +102,12 @@ model_default_cpp <- function(population, transmissibility = 1.3 / 7.0, infectiousness_rate = 1.0 / 2.0, recovery_rate = 1.0 / 7.0, - intervention = NULL, - vaccination = NULL, - time_dependence = NULL, + intervention = list( + contacts = no_contacts_intervention(population), + transmissibility = no_rate_intervention() + ), + vaccination = no_vaccination(population), + time_dependence = no_time_dependence(), time_end = 100, increment = 1) { # check class on required inputs @@ -116,12 +119,12 @@ model_default_cpp <- function(population, # all intervention sub-classes pass check for intervention superclass # note intervention and time-dependence targets are checked in dedicated fn - checkmate::assert_list( - intervention, - types = "intervention", null.ok = TRUE, - any.missing = FALSE, names = "unique" - ) - checkmate::assert_class(vaccination, "vaccination", null.ok = TRUE) + # checkmate::assert_list( + # intervention, + # types = "intervention", null.ok = TRUE, + # any.missing = FALSE, names = "unique" + # ) + checkmate::assert_class(vaccination, "vaccination") # check that time-dependence functions are passed as a list with at least the # arguments `time` and `x` @@ -265,9 +268,12 @@ model_default_r <- function(population, transmissibility = 1.3 / 7.0, infectiousness_rate = 1.0 / 2.0, recovery_rate = 1.0 / 7.0, - intervention = NULL, - vaccination = NULL, - time_dependence = NULL, + intervention = list( + contacts = no_contacts_intervention(population), + transmissibility = no_rate_intervention() + ), + vaccination = no_vaccination(population), + time_dependence = no_time_dependence(), time_end = 100, increment = 1) { # check class on required inputs diff --git a/tests/testthat/test-input_checking_intervention.R b/tests/testthat/test-input_checking_intervention.R index b5fb43fc..6c6d92f1 100644 --- a/tests/testthat/test-input_checking_intervention.R +++ b/tests/testthat/test-input_checking_intervention.R @@ -31,6 +31,20 @@ test_intervention_bad <- intervention( reduction = 0.2 ) +test_rate_intervention <- intervention( + type = "rate", + time_begin = 60, + time_end = 100, + reduction = 0.2 +) +test_bad_rate_intervention <- intervention( + type = "rate", + time_begin = 60, + time_end = 100, + reduction = matrix(0.2) +) + + test_that("Interventions are checked correctly", { # check for no conditions on a well formed intervention expect_no_condition( @@ -46,6 +60,21 @@ test_that("Interventions are checked correctly", { test_intervention # with population missing ) ) + + expect_no_condition( + assert_intervention( + type = "rate", + test_rate_intervention # with population missing + ) + ) + expect_no_condition( + assert_intervention( + type = "rate", + test_rate_intervention, + population = test_population + ) + ) + expect_error( assert_intervention( test_intervention, @@ -62,4 +91,11 @@ test_that("Interventions are checked correctly", { population = test_population ) ) + expect_error( + assert_intervention( + test_bad_rate_intervention, + "contacts", + population = test_population + ) + ) }) diff --git a/vignettes/rate_interventions.Rmd b/vignettes/rate_interventions.Rmd index 43904e6d..0a580c78 100644 --- a/vignettes/rate_interventions.Rmd +++ b/vignettes/rate_interventions.Rmd @@ -130,7 +130,10 @@ data <- model_default_cpp( # with a mask mandate data_masks <- model_default_cpp( population = uk_population, - intervention = list(transmissibility = mask_mandate), + intervention = list( + transmissibility = mask_mandate, + contacts = no_contacts_intervention(uk_population) + ), time_end = 200, increment = 1.0 ) ```