Skip to content

Commit

Permalink
Merge pull request #222 from mrc-ide/i220-substitute
Browse files Browse the repository at this point in the history
Allow substitution of user values at compile time
  • Loading branch information
weshinsley authored Mar 19, 2021
2 parents ab0dab4 + ec167c5 commit 4ae3d1a
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.1.10
Version: 1.1.11
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# odin 1.1.11

* New option `substitutions` (via `odin::odin_options`) which can substitute in scalar `user` values at compile time (#220)

# odin 1.1.9

* New option `rewrite_dims` (via `odin::odin_options`) which will attempt to simplify common dimensions. This can reduce the number of variables carried around in the model as these are typically very redundant and also known at compile time (mrc-2093)
Expand Down
33 changes: 33 additions & 0 deletions R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ ir_parse <- function(x, options, type = NULL) {
## This performs a round of optimisation where we try to simplify
## away expressions for the dimensions, which reduces the number of
## required variables.
eqs <- ir_parse_substitute(eqs, options$substitutions)
if (options$rewrite_dims && features$has_array) {
eqs <- ir_parse_rewrite_dims(eqs)
}
Expand Down Expand Up @@ -1520,6 +1521,38 @@ ir_parse_expr_rhs_check_inplace <- function(lhs, rhs, line, source) {
}


ir_parse_substitute <- function(eqs, subs) {
if (is.null(subs)) {
return(eqs)
}

f <- function(nm) {
eq <- eqs[[nm]]
if (is.null(eq)) {
stop(sprintf("Substitution failed: '%s' is not an equation", nm),
call. = FALSE)
}
if (eq$type != "user") {
stop(sprintf("Substitution failed: '%s' is not a user() equation", nm),
call. = FALSE)
}
if (!is.null(eq$array)) {
stop(sprintf("Substitution failed: '%s' is an array", nm), call. = FALSE)
}
value <- support_coerce_mode(subs[[nm]], eq$user$integer,
eq$user$min, eq$user$max, nm)

eq$type <- "expression_scalar"
eq$rhs <- list(value = value)
eq$stochastic <- FALSE
eq
}

eqs[names(subs)] <- lapply(names(subs), f)
eqs
}


## This approach could probably be applied over the whole tree really,
## as we might be able to eliminate some other compile time
## things. However, doing that will make the models less debuggable.
Expand Down
28 changes: 26 additions & 2 deletions R/odin_options.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
##' messages with this option set to `TRUE` because parts of the
##' model have been effectively evaluated during processing.
##'
##' @param substitutions Optionally, a list of values to substitute into
##' model specification as constants, even though they are declared
##' as `user()`. This will be most useful in conjunction with
##' `rewrite_dims` to create a copy of your model with dimensions
##' known at compile time and all loops using literal integers.
##'
##' @return A list of parameters, of class `odin_options`
##'
##' @export
Expand All @@ -23,7 +29,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
validate = NULL, pretty = NULL, skip_cache = NULL,
compiler_warnings = NULL,
no_check_unused_equations = NULL,
rewrite_dims = NULL,
rewrite_dims = NULL, substitutions = NULL,
options = NULL) {
default_target <-
if (is.null(target) && !can_compile(verbose = FALSE)) "r" else "c"
Expand All @@ -35,6 +41,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
pretty = FALSE,
skip_cache = FALSE,
rewrite_dims = FALSE,
substitutions = NULL,
no_check_unused_equations = FALSE,
compiler_warnings = FALSE)
if (is.null(options)) {
Expand All @@ -46,14 +53,15 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
workdir = workdir,
skip_cache = assert_scalar_logical_or_null(skip_cache),
rewrite_dims = assert_scalar_logical_or_null(rewrite_dims),
substitutions = check_substitutions(substitutions),
no_check_unused_equations =
assert_scalar_logical_or_null(no_check_unused_equations),
compiler_warnings = assert_scalar_logical_or_null(compiler_warnings))
}
stopifnot(all(names(defaults) %in% names(options)))

for (i in names(defaults)) {
if (is.null(options[[i]])) {
if (is.null(options[[i]]) && i != "substitutions") {
options[[i]] <- getOption(paste0("odin.", i), defaults[[i]])
}
}
Expand All @@ -69,3 +77,19 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
class(options) <- "odin_options"
options
}


check_substitutions <- function(substitutions) {
if (is.null(substitutions)) {
return(NULL)
}
assert_named(substitutions, TRUE)
assert_is(substitutions, "list")
ok <- vlapply(substitutions, function(x)
is.numeric(x) && length(x) == 1L)
if (any(!ok)) {
stop("Invalid entry in substitutions: ",
paste(squote(names_if(!ok)), collapse = ", "))
}
substitutions
}
1 change: 1 addition & 0 deletions R/odin_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,6 @@ odin_parse <- function(x, type = NULL, options = NULL) {
##' @rdname odin_parse
odin_parse_ <- function(x, options = NULL, type = NULL) {
options <- odin_options(options = options)
assert_scalar_character_or_null(type)
ir_parse(x, options, type)
}
28 changes: 24 additions & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,39 @@ na_drop <- function(x) {

assert_scalar_logical_or_null <- function(x, name = deparse(substitute(x))) {
if (!is.null(x)) {
if (length(x) != 1 || !is.logical(x) || !is.na(x)) {
if (length(x) != 1 || !is.logical(x) || is.na(x)) {
stop(sprintf("Expected '%s' to be a logical scalar (or NULL)", name))
}
}
invisible(x)
}


assert_scalar_logical_or_null <- function(x, name = deparse(substitute(x))) {
assert_scalar_character_or_null <- function(x, name = deparse(substitute(x))) {
if (!is.null(x)) {
if (length(x) != 1 || !is.logical(x) || is.na(x)) {
stop(sprintf("Expected '%s' to be a logical scalar (or NULL)", name))
if (length(x) != 1 || !is.character(x) || is.na(x)) {
stop(sprintf("Expected '%s' to be a character scalar (or NULL)", name))
}
}
invisible(x)
}


assert_named <- function(x, unique = FALSE, name = deparse(substitute(x))) {
if (is.null(names(x))) {
stop(sprintf("'%s' must be named", name), call. = FALSE)
}
if (unique && any(duplicated(names(x)))) {
stop(sprintf("'%s' must have unique names", name), call. = FALSE)
}
invisible(x)
}


assert_is <- function(x, what, name = deparse(substitute(x))) {
if (!inherits(x, what)) {
stop(sprintf("'%s' must be a %s", name, paste(what, collapse = " / ")),
call. = FALSE)
}
invisible(x)
}
8 changes: 7 additions & 1 deletion man/odin_options.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 72 additions & 0 deletions tests/testthat/test-parse2-rewrite.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,75 @@ test_that("rewrite arrays drops references to dim_ variables", {
expect_false(grepl("dim_S_1", ir))
expect_false(grepl("dim_S", ir))
})


test_that("Can create compile-time constants", {
ir <- odin_parse({
n <- user(integer = TRUE)
m <- user()
deriv(S[, ]) <- 0
deriv(I) <- S[n, m]
dim(S) <- c(n, m)
initial(S[, ]) <- S0[i, j]
initial(I) <- 0
S0[, ] <- user()
dim(S0) <- c(n, m)
}, options = odin_options(rewrite_dims = TRUE,
substitutions = list(n = 2, m = 3)))
dat <- ir_deserialise(ir)
expect_equal(dat$equations$n$type, "expression_scalar")
})


test_that("Can validate substitutions", {
code <- quote({
n <- user(integer = TRUE, min = 2)
m <- user(max = 10)
a <- 1
deriv(S[, ]) <- 0
deriv(I) <- S[n, m]
dim(S) <- c(n, m)
initial(S[, ]) <- S0[i, j] * a
initial(I) <- 0
S0[, ] <- user()
dim(S0) <- c(n, m)
})

expect_error(
odin_parse(code, options = odin_options(substitutions = list(y = 1))),
"Substitution failed: 'y' is not an equation")
expect_error(
odin_parse(code, options = odin_options(substitutions = list(a = 1))),
"Substitution failed: 'a' is not a user() equation", fixed = TRUE)
expect_error(
odin_parse(code, options = odin_options(substitutions = list(S0 = 1))),
"Substitution failed: 'S0' is an array", fixed = TRUE)
expect_error(
odin_parse(code, options = odin_options(substitutions = list(n = 1))),
"Expected 'n' to be at least 2")
expect_error(
odin_parse(code, options = odin_options(substitutions = list(n = 2.4))),
"Expected 'n' to be integer-like")
expect_error(
odin_parse(code, options = odin_options(substitutions = list(m = 20))),
"Expected 'm' to be at most 10")
expect_error(
odin_parse(code,
options = odin_options(substitutions = list(m = NA_real_))),
"'m' must not contain any NA values")
expect_error(
odin_parse(code,
options = odin_options(substitutions =
list(m = NULL, n = NULL))),
"Invalid entry in substitutions: 'm', 'n'")
expect_error(
odin_parse(code, options = odin_options(substitutions = list(1, 2))),
"'substitutions' must be named")
expect_error(
odin_parse(code, options =
odin_options(substitutions = list(n = 1, n = 1))),
"'substitutions' must have unique names")
expect_error(
odin_parse(code, options = odin_options(substitutions = c(n = 1))),
"'substitutions' must be a list")
})
23 changes: 23 additions & 0 deletions tests/testthat/test-run-basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,26 @@ test_that_odin("Can set initial conditions directly in an ode", {
y <- mod$run(0:10, 2)
expect_equal(y[, "y"], seq(2, by = 2, length.out = 11))
})


test_that_odin("Can substitute user variables", {
gen <- odin({
n <- user(integer = TRUE)
m <- user()
deriv(S[, ]) <- 0
deriv(I) <- S[n, m]
dim(S) <- c(n, m)
initial(S[, ]) <- S0[i, j]
initial(I) <- 0
S0[, ] <- user()
dim(S0) <- c(n, m)
}, options = odin_options(rewrite_dims = TRUE,
substitutions = list(n = 2, m = 3)))
expect_equal(nrow(coef(gen)), 1) # only S0 now
S0 <- matrix(rpois(6, 10), 2, 3)
mod <- gen(S0 = S0)
dat <- mod$contents()
expect_equal(dat$n, 2)
expect_equal(dat$m, 3)
expect_equal(dat$initial_S, S0)
})
40 changes: 40 additions & 0 deletions tests/testthat/test-util.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,43 @@ test_that("validate inputs", {
expect_error(assert_scalar_logical_or_null(logical(0)),
"Expected '.+' to be a logical scalar \\(or NULL\\)")
})


test_that("validate inputs", {
expect_silent(assert_scalar_character_or_null(NULL))
expect_silent(assert_scalar_character_or_null("a"))

thing <- TRUE
expect_error(
assert_scalar_character_or_null(thing),
"Expected 'thing' to be a character scalar (or NULL)",
fixed = TRUE)
expect_error(assert_scalar_character_or_null(NA),
"Expected '.+' to be a character scalar \\(or NULL\\)")
expect_error(assert_scalar_character_or_null(character(0)),
"Expected '.+' to be a character scalar \\(or NULL\\)")
})


test_that("check names", {
expect_error(
assert_named(list()),
"must be named")
expect_error(
assert_named(list(1, 2)),
"must be named")
expect_silent(
assert_named(list(a = 1, a = 2)))
expect_error(
assert_named(list(a = 1, a = 2), TRUE),
"must have unique names")
})


test_that("Check S3 class", {
expect_silent(assert_is(structure(1, class = "foo"), "foo"))
expect_error(assert_is(structure(1, class = "bar"), "foo"),
"must be a foo")
expect_error(assert_is(1, c("foo", "bar")),
"must be a foo / bar")
})

0 comments on commit 4ae3d1a

Please sign in to comment.