Skip to content

Commit

Permalink
Merge pull request #219 from mrc-ide/mrc-2093
Browse files Browse the repository at this point in the history
Optionally remove dimensions
  • Loading branch information
weshinsley authored Mar 17, 2021
2 parents 60d2614 + f2b63a8 commit 0dc9a9d
Show file tree
Hide file tree
Showing 27 changed files with 551 additions and 114 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.1.8
Version: 1.1.9
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//')
RSCRIPT = Rscript --no-init-file
RSCRIPT = Rscript

all: install

Expand Down
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# odin 1.1.9

* New option `rewrite_dims` (via `odin::odin_options`) which will attempt to simplify common dimensions. This can reduce the number of variables carried around in the model as these are typically very redundant and also known at compile time (mrc-2093)

# odin 1.1.8

* Annotate equations with `# ignore.unused` to locally suppress messages about unused variables (mrc-2122)
Expand All @@ -19,7 +23,7 @@
* The basic infrastructure has been overhauled, which will make some alternative compilation targets easier to support. We now use `pkgbuild` for the compilation which should ease debugging, and odin code compiled into packages will no longer issue a slew of warnings (and cooperate with automatic routine registration). This refactor has caused a few minor breaking changes:
- `$initial()` always requires time, even if it is ignored
- `$set_user()` and construction no longer work with positional argument matching - all arguments must be named
- The `$ir` field has become a method; add parens after it
- The `$ir` field has become a method; add parentheses after it
- The `compiler_warnings` option has been removed

# odin 1.0.7
Expand Down
1 change: 1 addition & 0 deletions R/dependencies.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ find_symbols <- function(expr, hide_errors = TRUE) {
if (length(e) >= 2L) {
## The if here avoids an invalid parse, e.g. length(); we'll
## pick that up later on.
## This is the one problematic use
variables$add(array_dim_name(deparse(e[[2L]])))
}
## Still need to declare the function as used because we'll
Expand Down
8 changes: 6 additions & 2 deletions R/generate_c_compiled.R
Original file line number Diff line number Diff line change
Expand Up @@ -579,12 +579,16 @@ generate_c_compiled_metadata <- function(dat, rewrite) {
sprintf_safe("SET_VECTOR_ELT(%s, %d, ScalarInteger(%s));",
target, i - 1L, rewrite(d$dimnames$length))
} else {
## NOTE: need to use array_dim_name here because we might have
## removed the dimension variable. However, this exists only
## through a short scope here and we could really use anything.
name <- array_dim_name(d$name)
c(sprintf_safe("SET_VECTOR_ELT(%s, %d, allocVector(INTSXP, %d));",
target, i - 1L, d$rank),
sprintf_safe("int * %s = INTEGER(VECTOR_ELT(%s, %d));",
d$dimnames$length, target, i - 1L),
name, target, i - 1L),
sprintf_safe("%s[%d] = %s;",
d$dimnames$length, seq_len(d$rank) - 1L,
name, seq_len(d$rank) - 1L,
vcapply(d$dimnames$dim, rewrite, USE.NAMES = FALSE)))
}
}
Expand Down
13 changes: 0 additions & 13 deletions R/ir_deserialise.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ ir_deserialise <- function(ir) {
dat$version <- numeric_version(dat$version)
dat$components <- lapply(dat$components, lapply, list_to_character)

if (dat$features$has_array) {
dat$data$elements <- lapply(dat$data$elements, ir_deserialise_data_dimnames)
}

names(dat$data$elements) <- vcapply(dat$data$elements, "[[", "name")
names(dat$data$variable$contents) <-
vcapply(dat$data$variable$contents, "[[", "name")
Expand Down Expand Up @@ -65,12 +61,3 @@ ir_deserialise_equation <- function(eq) {
}
eq
}


## Apply list_to_character to the "dim" and "mult" entries of an
## element's dimnames. Elements whose rank is not positive are
## returned untouched.
ir_deserialise_data_dimnames <- function(x) {
  if (!(x$rank > 0L)) {
    return(x)
  }
  fields <- c("dim", "mult")
  x$dimnames[fields] <- lapply(x$dimnames[fields], list_to_character)
  x
}
166 changes: 141 additions & 25 deletions R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ ir_parse <- function(x, options, type = NULL) {

eqs <- ir_parse_arrays(eqs, variables, config$include$names, source)

## This performs a round of optimisation where we try to simplify
## away expressions for the dimensions, which reduces the number of
## required variables.
if (options$rewrite_dims && features$has_array) {
eqs <- ir_parse_rewrite_dims(eqs)
}

packing <- ir_parse_packing(eqs, variables, source)
eqs <- c(eqs, packing$offsets)
packing$offsets <- NULL
Expand Down Expand Up @@ -318,13 +325,18 @@ ir_parse_stage <- function(eqs, dependencies, variables, time_name, source) {
stage[names_if(vlapply(eqs, is_null))] <- STAGE_NULL

i <- vlapply(eqs, function(x) !is.null(x$array))
len <- unique(vcapply(eqs[i], function(x) x$array$dimnames$length))
err <- stage[len] == STAGE_TIME
len <- lapply(eqs[i], function(x) x$array$dimnames$length)
## We sometimes end up with a string and sometimes with a symbol
## here, which is unsatisfactory.
len_var <- vcapply(len[!vlapply(len, is.numeric)], as.character)
err <- stage[len_var] == STAGE_TIME

if (any(err)) {
## TODO: in the case where we rewrite dimensions this error is not
## great because we've lost the dim() call!
ir_parse_error(
"Array extent is determined by time",
ir_parse_error_lines(eqs[len[err]]), source)
ir_parse_error_lines(eqs[len_var[err]]), source)
}

stage
Expand All @@ -342,8 +354,11 @@ ir_parse_packing_new <- function(eqs, variables, offset_prefix) {

ir_parse_packing_internal <- function(names, rank, len, variables,
offset_prefix) {
## We'll pack from least to most complex:
i <- order(rank)
## We'll pack from least to most complex and everything with a fixed
## offset first. This puts all scalars first, then all arrays that
## have compile-time size next (in order of rank), then all arrays
## with user-time size (in order of rank).
i <- order(!vlapply(len, is.numeric), rank)
names <- names[i]
rank <- rank[i]
len <- len[i]
Expand All @@ -355,10 +370,9 @@ ir_parse_packing_internal <- function(names, rank, len, variables,
for (i in seq_along(names)) {
if (!is_array[[i]]) {
offset[[i + 1L]] <- i
} else if (identical(offset[[i]], 0L)) {
offset[[i + 1L]] <- as.name(len[[i]])
} else {
offset[[i + 1L]] <- call("+", offset[[i]], as.name(len[[i]]))
len_i <- if (is.numeric(len[[i]])) len[[i]] else as.name(len[[i]])
offset[[i + 1L]] <- static_eval(call("+", offset[[i]], len_i))
}
}

Expand Down Expand Up @@ -1190,10 +1204,11 @@ ir_parse_delay <- function(eqs, discrete, variables, source) {
## TODO: ideally we'd get the correct lines here for source, but
## that's low down the list of needs.
f <- function(x) {
depends <- if (is.numeric(x$dim)) character(0) else as.character(x$dim)
list(name = x$to,
type = "alloc",
source = integer(0),
depends = ir_parse_depends(variables = x$dim),
depends = ir_parse_depends(variables = depends),
lhs = list(name_data = x$to,
name_lhs = x$to,
name_equation = x$to),
Expand Down Expand Up @@ -1222,8 +1237,11 @@ ir_parse_delay_discrete <- function(eq, eqs, source) {
nm <- eq$name
nm_ring <- sprintf("delay_ring_%s", nm)

depends_ring <- list(functions = character(0),
variables = eq$array$dimnames$length %||% character(0))
len <- eq$array$dimnames$length
depends_ring <- list(
functions = character(0),
variables = if (is.character(len)) len else character(0))

lhs_ring <- list(name_data = nm_ring, name_equation = nm_ring,
name_lhs = nm_ring, storage_type = "ring_buffer")
eq_ring <- list(
Expand Down Expand Up @@ -1279,19 +1297,27 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) {
substitutions <- list()
}

eq_len <- list(
name = nm_dim,
type = "expression_scalar",
source = eq$source,
depends = find_symbols(graph$packing$length),
lhs = list(name_data = nm_dim, name_equation = nm_dim, name_lhs = nm_dim,
storage_type = "int"),
rhs = list(value = graph$packing$length))
if (is.numeric(graph$packing$length)) {
eq_len <- NULL
val_len <- graph$packing$length
dep_len <- character(0)
} else {
eq_len <- list(
name = nm_dim,
type = "expression_scalar",
source = eq$source,
depends = find_symbols(graph$packing$length),
lhs = list(name_data = nm_dim, name_equation = nm_dim, name_lhs = nm_dim,
storage_type = "int"),
rhs = list(value = graph$packing$length))
val_len <- nm_dim
dep_len <- nm_dim
}

lhs_use <- eq$lhs[c("name_data", "name_equation", "name_lhs", "special")]
subs_from <- vcapply(substitutions, "[[", "to")
depends_use <- join_deps(list(
eq$depends, ir_parse_depends(variables = c(nm_dim, subs_from, TIME))))
eq$depends, ir_parse_depends(variables = c(dep_len, subs_from, TIME))))

eq_use <- list(
name = nm,
Expand All @@ -1304,15 +1330,15 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) {
state = nm_state,
index = nm_index,
substitutions = substitutions,
variables = list(length = eq_len$name,
variables = list(length = val_len,
contents = graph$packing$contents),
equations = graph$equations,
default = eq$delay$default,
time = eq$delay$time,
depends = eq$delay$depends),
array = eq$array)

array <- list(dimnames = list(length = nm_dim, dim = NULL, mult = NULL),
array <- list(dimnames = list(length = val_len, dim = NULL, mult = NULL),
rank = 1L)
lhs_index <-
list(name_data = nm_index, name_equation = nm_index, name_lhs = nm_index,
Expand All @@ -1322,7 +1348,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) {
offsets <- lapply(variables$contents[match(graph$variables, variable_names)],
"[[", "offset")
depends_index <- join_deps(lapply(offsets, find_symbols))
depends_index$variables <- union(depends_index$variables, nm_dim)
depends_index$variables <- union(depends_index$variables, dep_len)
eq_index <- list(
name = nm_index,
type = "delay_index",
Expand All @@ -1339,7 +1365,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) {
name = nm_state,
type = "null",
source = eq$source,
depends = ir_parse_depends(variables = nm_dim),
depends = ir_parse_depends(variables = dep_len),
lhs = lhs_state,
array = array)

Expand All @@ -1348,7 +1374,9 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) {
eq_index$depends$variables <- c(eq_index$depends$variables, names(offsets))
}

extra <- c(list(eq_len, eq_index, eq_state, eq_use), offsets)
extra <- c(if (is.null(eq_len)) NULL else list(eq_len),
list(eq_index, eq_state, eq_use),
offsets)
names(extra) <- vcapply(extra, "[[", "name")

stopifnot(sum(names(eqs) == eq$name) == 1)
Expand Down Expand Up @@ -1490,3 +1518,91 @@ ir_parse_expr_rhs_check_inplace <- function(lhs, rhs, line, source) {
line, source)
}
}


## Simplify away equations that only carry array dimensions.
##
## Each equation whose name matches "dim_" has its right-hand side
## resolved (following scalar expressions and length()/dim() calls)
## and then passed through static_eval(). Where the result is a number
## or a single symbol, the equation is dropped from the returned list
## and references to it (in rhs values, dependency lists, array
## dimnames and delay dependencies) are replaced by that value. Among
## the remaining dimension equations, those with identical non-NULL
## values are redirected to the first equation with that value.
##
## The same approach could probably be applied over the whole tree, as
## other compile-time values might be eliminated too; however, doing
## that would make the models less debuggable.
##
## @param eqs Named list of equations.
## @return The equations that were not substituted away, rewritten to
##   use the substituted values.
ir_parse_rewrite_dims <- function(eqs) {
  ## Reduce an expression as far as possible by chasing symbols that
  ## refer to scalar expressions and by turning length()/dim() calls
  ## into the corresponding dimension variable.
  resolve <- function(expr) {
    if (is.numeric(expr)) {
      return(expr)
    }
    if (is.symbol(expr)) {
      target <- eqs[[deparse_str(expr)]]
      ## identical() copes with 'target' being NULL, which happens
      ## when 't' is passed through (that is an error elsewhere).
      if (!identical(target$type, "expression_scalar")) {
        return(expr)
      }
      return(resolve(target$rhs$value))
    }
    if (is_call(expr, "length")) {
      ## NOTE: go via array_dim_name because we might see things like
      ## length(y) where 'y' is one of the variables; we can't look up
      ## eqs[[name]]$array$length without checking that.
      len_name <- array_dim_name(as.character(expr[[2]]))
      return(resolve(as.name(len_name)))
    }
    if (is_call(expr, "dim")) {
      dim_name <- array_dim_name(as.character(expr[[2]]), expr[[3]])
      return(resolve(as.name(dim_name)))
    }
    if (is.recursive(expr)) {
      ## Resolve every argument of the call, leaving the head alone.
      expr[-1] <- lapply(expr[-1], resolve)
    }
    ## Anything left over (e.g. NULL) is returned unchanged.
    expr
  }

  ## Alternatively we could look in all $array$dimnames elements.
  ## NOTE(review): the pattern is not anchored, so any equation name
  ## containing "dim_" is treated as a dimension - confirm intended.
  dim_names <- grep("dim_", names(eqs), value = TRUE)

  values <- lapply(eqs[dim_names], function(eq) {
    static_eval(resolve(eq$rhs$value))
  })

  ## Dimensions that collapse to a symbol or a number are substituted.
  can_inline <- vlapply(values, function(v) is.symbol(v) || is.numeric(v))
  subs <- values[can_inline]

  ## Try and deduplicate the rest. However, it's not totally obvious
  ## that we can do this without creating some weird dependency
  ## issues.
  remaining <- values[!can_inline]
  ## NULL dimensions are never deduplicated; these are set by user()
  ## later.
  is_dup <- duplicated(remaining) & !vlapply(remaining, is.null)
  if (any(is_dup)) {
    first <- match(remaining[is_dup], remaining)
    aliases <- lapply(names(remaining)[first], as.name)
    subs <- c(subs, set_names(aliases, names(remaining)[is_dup]))
  }

  ## Swap entries of 'x' that name an element of 'lookup' for that
  ## element, then pass the result through na_drop (presumably
  ## discarding NA entries, i.e. dependencies that became constants).
  swap <- function(x, lookup) {
    pos <- match(vcapply(x, function(el) el %||% ""), names(lookup))
    hit <- which(!is.na(pos))
    x[hit] <- unname(lookup)[pos[hit]]
    na_drop(x)
  }

  subs_env <- list2env(subs, parent = emptyenv())
  ## Dependency form of the substitutions: numbers carry no
  ## dependency (NA), symbols depend on the variable they name.
  subs_dep <- vcapply(subs, function(s) {
    if (is.numeric(s)) NA_character_ else deparse_str(s)
  })

  rewrite_eq <- function(eq) {
    eq$rhs$value <- substitute_(eq$rhs$value, subs_env)

    eq$depends$variables <- swap(eq$depends$variables, subs_dep)
    eq$lhs$depends$variables <- swap(eq$lhs$depends$variables, subs_dep)

    if (!is.null(eq$array$dimnames)) {
      eq$array$dimnames$length <- swap(eq$array$dimnames$length, subs)[[1]]
      eq$array$dimnames$dim <- swap(eq$array$dimnames$dim, subs)
      eq$array$dimnames$mult <- swap(eq$array$dimnames$mult, subs)
    }

    if (!is.null(eq$delay)) {
      eq$delay$depends$variables <-
        swap(eq$delay$depends$variables, subs_dep)
    }

    eq
  }

  ## Substituted equations are no longer needed; rewrite the others.
  keep <- setdiff(names(eqs), names(subs))
  lapply(eqs[keep], rewrite_eq)
}
9 changes: 5 additions & 4 deletions R/ir_parse_arrays.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,13 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) {
depends = depends_dim)
}
eq_dim <- lapply(seq_len(rank), f_eq_dim)
dimnames$dim <- vcapply(eq_dim, "[[", "name")
dimnames$dim <- lapply(eq_dim, "[[", "name")

## At this point, modify how we compute total length:
dims <- lapply(dimnames$dim, as.name)
eq_length$rhs$value <- r_fold_call("*", dims)
eq_length$depends <- list(functions = character(0),
variables = dimnames$dim)
variables = list_to_character(dimnames$dim))

## Even more bits
if (rank > 2L) {
Expand All @@ -488,11 +488,12 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) {
implicit = TRUE,
source = eq$source,
depends = list(functions = character(0),
variables = dimnames$dim[j]))
variables = list_to_character(dimnames$dim[j])))
}
eq_mult <- lapply(3:rank, f_eq_mult)
}
dimnames$mult <- c("", dimnames$dim[[1]], vcapply(eq_mult, "[[", "name"))
dimnames$mult <- c(list("", dimnames$dim[[1]]),
lapply(eq_mult, "[[", "name"))
}

no_alloc <-
Expand Down
2 changes: 2 additions & 0 deletions R/ir_serialise.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ ir_serialise_data <- function(data) {
} else {
ret$dimnames <- x$dimnames
ret$dimnames$length <- ir_serialise_expression(x$dimnames$length)
ret$dimnames$dim <- lapply(ret$dimnames$dim, ir_serialise_expression)
ret$dimnames$mult <- lapply(ret$dimnames$mult, ir_serialise_expression)
}
ret$stage <- scalar(STAGE_NAME[x$stage + 1L])
ret
Expand Down
Loading

0 comments on commit 0dc9a9d

Please sign in to comment.