Skip to content

Commit

Permalink
Merge pull request #223 from mrc-ide/mrc-2252
Browse files Browse the repository at this point in the history
Rewrite compile time constants
  • Loading branch information
weshinsley authored Mar 23, 2021
2 parents 4ae3d1a + 4f9a53f commit 32cabd2
Show file tree
Hide file tree
Showing 18 changed files with 286 additions and 99 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.1.11
Version: 1.1.12
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# odin 1.1.12

* New option `rewrite_constants` (via `odin::odin_options`) which attempts to rewrite all constants in the model code before generation. This can considerably reduce the number of variable lookups (mrc-2252)

# odin 1.1.11

* New option `substitutions` (via `odin::odin_options`) which can substitute in scalar `user` values at compile time (#220)
Expand Down
2 changes: 1 addition & 1 deletion R/generate_c_compiled.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ generate_c_compiled_rhs_r <- function(dat, rewrite) {
sprintf_safe("const %s %s = %s;",
time_type, dat$meta$initial_time, initial_time),
c_expr_if(sprintf_safe("ISNA(%s)", dat$meta$initial_time),
sprintf_safe('%s = %s(%s, "dat$meta$time");',
sprintf_safe('%s = %s(%s, "%s");',
initial_time, time_access, dat$meta$time,
dat$meta$time)))
reset_initial_time <-
Expand Down
127 changes: 92 additions & 35 deletions R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ ir_parse <- function(x, options, type = NULL) {
variables <- ir_parse_find_variables(eqs, features$discrete, source)

eqs <- lapply(eqs, ir_parse_rewrite_initial, variables)

eqs <- ir_parse_arrays(eqs, variables, config$include$names, source)

## This performs a round of optimisation where we try to simplify
## away expressions for the dimensions, which reduces the number of
## required variables.
eqs <- ir_parse_substitute(eqs, options$substitutions)
if (options$rewrite_dims && features$has_array) {
if (options$rewrite_constants) {
eqs <- ir_parse_rewrite_constants(eqs)
} else if (options$rewrite_dims && features$has_array) {
eqs <- ir_parse_rewrite_dims(eqs)
}

Expand Down Expand Up @@ -1553,72 +1554,123 @@ ir_parse_substitute <- function(eqs, subs) {
}


## This approach could probably be applied over the whole tree really,
## as we might be able to eliminate some other compile time
## things. However, doing that will make the models less debuggable.
ir_parse_rewrite_dims <- function(eqs) {
compute <- function(x) {
if (is.numeric(x)) {
x
} else if (is.symbol(x)) {
nms <- names_if(vlapply(eqs, function(x) isTRUE(x$lhs$dim)))
ir_parse_rewrite(nms, eqs)
}


ir_parse_rewrite_constants <- function(eqs) {
nms <- names_if(vlapply(eqs, function(x) x$type == "expression_scalar"))
ir_parse_rewrite(nms, eqs)
}


ir_parse_rewrite_compute_eqs <- function(nms, eqs) {
cache <- new_empty_env()
lapply(eqs[nms], function(eq)
static_eval(ir_parse_rewrite_compute(eq$rhs$value, eqs, cache)))
}


ir_parse_rewrite_compute <- function(x, eqs, cache) {
key <- deparse_str(x)
if (key %in% names(cache)) {
return(cache[[key]])
}

if (!is.numeric(x)) {
if (is.symbol(x)) {
x_eq <- eqs[[deparse_str(x)]]
## use identical() here to cope with x_eq being NULL when 't' is
## passed through (that will be an error elsewhere).
if (identical(x_eq$type, "expression_scalar")) {
compute(x_eq$rhs$value)
} else {
x
x <- ir_parse_rewrite_compute(x_eq$rhs$value, eqs, cache)
}
} else if (is_call(x, "length")) {
## NOTE: use array_dim_name because we might hit things like
## length(y) where 'y' is one of the variables; we can't look up
## eqs[[name]]$array$length without checking that.
compute(as.name(array_dim_name(as.character(x[[2]]))))
length_name <- as.name(array_dim_name(as.character(x[[2]])))
x <- ir_parse_rewrite_compute(length_name, eqs, cache)
} else if (is_call(x, "dim")) {
compute(as.name(array_dim_name(as.character(x[[2]]), x[[3]])))
dim_name <- as.name(array_dim_name(as.character(x[[2]]), x[[3]]))
x <- ir_parse_rewrite_compute(dim_name, eqs, cache)
} else if (is.recursive(x)) {
x[-1] <- lapply(x[-1], compute)
x
} else { # NULL
x
x[-1] <- lapply(x[-1], ir_parse_rewrite_compute, eqs, cache)
}
}

## alternatively look in all $array$dimnames elements
nms <- grep("dim_", names(eqs), value = TRUE)
cache[[key]] <- x
x
}


val <- lapply(eqs[nms], function(eq)
static_eval(compute(eq$rhs$value)))
ir_parse_rewrite <- function(nms, eqs) {
val <- tryCatch(
ir_parse_rewrite_compute_eqs(nms, eqs),
error = function(e) {
message("Rewrite failure: ", e$message)
list()
})

rewrite <- vlapply(val, function(x) is.symbol(x) || is.numeric(x))

subs <- val[rewrite]

## One small wrinkle here: don't rewrite things that are the target
## of a copy as the rewrite is complicated. This affects almost
## nothing in reality outside the tests?
copy_self <- unlist(lapply(eqs, function(x)
if (x$type == "copy") x$lhs$name_data), FALSE)
subs <- subs[setdiff(names(subs), copy_self)]

is_dim <- vlapply(eqs, function(x) isTRUE(x$lhs$dim))

## Try and deduplicate the rest. However, it's not totally obvious
## that we can do this without creating some weird dependency
## issues.
leave <- val[!rewrite]
## Do not deduplicate NULL dimensions; these are set by user() later.
dup <- duplicated(leave) & !vlapply(leave, is.null)
if (any(dup)) {
i <- match(leave[dup], leave)
subs <- c(subs,
set_names(lapply(names(leave)[i], as.name), names(leave)[dup]))
## issues. Also we need to only treat dimensions (and possibly
## offsets); we could do any compile-time thing really but we don't
## know it yet. Propagating other expressions through though can
## create problems.
check <- val[intersect(names_if(!rewrite), names_if(is_dim))]

if (length(check) > 0) {
dup <- duplicated(check) & !vlapply(check, is.null)
if (any(dup)) {
i <- match(check[dup], check)
subs <- c(subs,
set_names(lapply(names(check)[i], as.name), names(check)[dup]))
}
}

subs_env <- list2env(subs, parent = emptyenv())
subs_dep <- vcapply(subs, function(x)
if (is.numeric(x)) NA_character_ else deparse_str(x))

replace <- function(x, y) {
i <- match(vcapply(x, function(x) x %||% ""), names(y))
j <- which(!is.na(i))
x[j] <- unname(y)[i[j]]
na_drop(x)
}

subs_env <- list2env(subs, parent = emptyenv())
subs_dep <- vcapply(subs, function(x)
if (is.numeric(x)) NA_character_ else deparse_str(x))
rewrite_eq_array_part <- function(el) {
el$value <- substitute_(el$value, subs_env)
for (i in seq_along(el$index)) {
el$index[[i]]$value <- substitute_(el$index[[i]]$value, subs_env)
}
el
}

rewrite_eq <- function(eq) {
eq$rhs$value <- substitute_(eq$rhs$value, subs_env)
if (eq$type == "expression_array") {
eq$rhs <- lapply(eq$rhs, rewrite_eq_array_part)
} else if (eq$name %in% names(subs)) {
eq$rhs$value <- subs[[eq$name]]
} else {
eq$rhs$value <- substitute_(eq$rhs$value, subs_env)
}

eq$depends$variables <- replace(eq$depends$variables, subs_dep)
eq$lhs$depends$variables <- replace(eq$lhs$depends$variables, subs_dep)
Expand All @@ -1630,12 +1682,17 @@ ir_parse_rewrite_dims <- function(eqs) {
}

if (!is.null(eq$delay)) {
eq$delay$time <- substitute_(eq$delay$time, subs_env)
eq$delay$depends$variables <-
replace(eq$delay$depends$variables, subs_dep)
}

eq
}

lapply(eqs[setdiff(names(eqs), names(subs))], rewrite_eq)
## Can't drop initial(), deriv(), or update() calls even if they are
## constants.
keep <- names_if(!vlapply(eqs, function(x) is.null(x$lhs$special)))
i <- setdiff(names(eqs), setdiff(names(subs), keep))
lapply(eqs[i], rewrite_eq)
}
7 changes: 4 additions & 3 deletions R/ir_parse_arrays.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,8 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) {
lhs = list(name_lhs = eq$name,
name_data = eq$name,
name_equation = eq$name,
storage_type = "int"),
storage_type = "int",
dim = TRUE),
rhs = eq$rhs,
depends = depends_dim,
source = eq$source)
Expand All @@ -458,7 +459,7 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) {
name = d,
type = type,
lhs = list(name_lhs = d, name_data = d, name_equation = d,
storage_type = "int"),
storage_type = "int", dim = TRUE),
rhs = list(value = eq$rhs$value[[i]]),
implicit = TRUE,
source = eq$source,
Expand All @@ -483,7 +484,7 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) {
name = d,
type = "expression_scalar",
lhs = list(name_lhs = d, name_data = d, name_equation = d,
storage_type = "int"),
storage_type = "int", dim = TRUE),
rhs = list(value = r_fold_call("*", dims[j])),
implicit = TRUE,
source = eq$source,
Expand Down
26 changes: 19 additions & 7 deletions R/odin_options.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@
##' messages with this option set to `TRUE` because parts of the
##' model have been effectively evaluated during processing.
##'
##' @param substitutions Optionally, a list of values to substitute into
##' model specification as constants, even though they are declared
##' as `user()`. This will be most useful in conjunction with
##' `rewrite_dims` to create a copy of your model with dimensions
##' known at compile time and all loops using literal integers.
##' @param rewrite_constants Logical, indicating if odin should try
##' and rewrite *all* constant scalars. This is a superset of
##' `rewrite_dims` and may be slow for large models. Doing this will
##' make your model less debuggable; error messages will reference
##' expressions that have been extensively rewritten, some variables
##' will have been removed entirely or merged with other identical
##' expressions, and the generated code may not be obviously
##' connected to the original code.
##'
##' @param substitutions Optionally, a list of values to substitute
##' into model specification as constants, even though they are
##' declared as `user()`. This will be most useful in conjunction
##' with `rewrite_dims` to create a copy of your model with
##' dimensions known at compile time and all loops using literal
##' integers.
##'
##' @return A list of parameters, of class `odin_options`
##'
Expand All @@ -29,8 +39,8 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
validate = NULL, pretty = NULL, skip_cache = NULL,
compiler_warnings = NULL,
no_check_unused_equations = NULL,
rewrite_dims = NULL, substitutions = NULL,
options = NULL) {
rewrite_dims = NULL, rewrite_constants = NULL,
substitutions = NULL, options = NULL) {
default_target <-
if (is.null(target) && !can_compile(verbose = FALSE)) "r" else "c"
defaults <- list(
Expand All @@ -41,6 +51,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
pretty = FALSE,
skip_cache = FALSE,
rewrite_dims = FALSE,
rewrite_constants = FALSE,
substitutions = NULL,
no_check_unused_equations = FALSE,
compiler_warnings = FALSE)
Expand All @@ -53,6 +64,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL,
workdir = workdir,
skip_cache = assert_scalar_logical_or_null(skip_cache),
rewrite_dims = assert_scalar_logical_or_null(rewrite_dims),
rewrite_constants = assert_scalar_logical_or_null(rewrite_constants),
substitutions = check_substitutions(substitutions),
no_check_unused_equations =
assert_scalar_logical_or_null(no_check_unused_equations),
Expand Down
23 changes: 17 additions & 6 deletions man/odin_options.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/helper-odin.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ test_that_odin <- function(desc, code) {
code_enq <- rlang::enquo(code)
for (target in targets) {
opts <- list(odin.target = target,
odin.rewrite_dims = target == "c")
odin.rewrite_constants = target == "c")
testthat::test_that(sprintf("%s (%s)", desc, target),
withr::with_options(opts, rlang::eval_tidy(code_enq)))
}
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-ir.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ test_that("deserialise", {


test_that("Stage information included in IR", {
ir <- odin_parse_("examples/array_odin.R")
ir <- odin_parse_("examples/array_odin.R",
options = odin_options(rewrite_constants = FALSE))
dat <- odin_ir_deserialise(ir)
expect_equal(dat$data$elements$N_age$stage, "constant")
expect_equal(dat$data$elements$I_tot$stage, "time")
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-odin-validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ test_that("invalid model", {

test_that("unused variables can be detected", {
code <- c("initial(x) <- 1", "deriv(x) <- 1", "a <- 1")
res <- odin_validate(code, "text")
res <- odin_validate(code, "text",
odin_options(rewrite_constants = FALSE))
expect_equal(length(res$messages), 1L)
expect_match(res$messages[[1]]$msg, "Unused equation: a")
expect_equivalent(res$messages[[1]]$line, 3)
Expand Down
Loading

0 comments on commit 32cabd2

Please sign in to comment.