From 7205786fb03cb87425e6a0d8cde963a739a03234 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 8 Mar 2021 08:37:12 +0000 Subject: [PATCH 01/24] Towards proof-of-concept dimension removal --- R/dependencies.R | 1 + R/generate_c_compiled.R | 8 ++- R/ir_deserialise.R | 13 ---- R/ir_parse.R | 89 ++++++++++++++++++++++++++-- R/ir_parse_arrays.R | 7 ++- R/ir_serialise.R | 2 + R/odin.R | 29 +++++---- R/odin_options.R | 3 + R/opt.R | 62 +++++++++++++++++++ R/utils.R | 15 +++++ inst/schema.json | 6 +- tests/testthat/test-opt.R | 30 ++++++++++ tests/testthat/test-parse2-general.R | 24 ++++++++ tests/testthat/test-parse2-rewrite.R | 64 ++++++++++++++++++++ tests/testthat/test-run-opt.R | 19 ++++++ 15 files changed, 336 insertions(+), 36 deletions(-) create mode 100644 R/opt.R create mode 100644 tests/testthat/test-opt.R create mode 100644 tests/testthat/test-run-opt.R diff --git a/R/dependencies.R b/R/dependencies.R index a5f114fe..efa99a10 100644 --- a/R/dependencies.R +++ b/R/dependencies.R @@ -20,6 +20,7 @@ find_symbols <- function(expr, hide_errors = TRUE) { if (length(e) >= 2L) { ## The if here avoids an invalid parse, e.g. length(); we'll ## pick that up later on. + ## This is the one problematic use variables$add(array_dim_name(deparse(e[[2L]]))) } ## Still need to declare the function as used because we'll diff --git a/R/generate_c_compiled.R b/R/generate_c_compiled.R index 46aa2d0c..7f1e1594 100644 --- a/R/generate_c_compiled.R +++ b/R/generate_c_compiled.R @@ -579,12 +579,16 @@ generate_c_compiled_metadata <- function(dat, rewrite) { sprintf_safe("SET_VECTOR_ELT(%s, %d, ScalarInteger(%s));", target, i - 1L, rewrite(d$dimnames$length)) } else { + ## NOTE: need to use array_dim_name here because we might have + ## removed the dimension variable. However, this exists only + ## through a short scope here and we could really use anything. + name <- array_dim_name(d$name) c(sprintf_safe("SET_VECTOR_ELT(%s, %d, allocVector(INTSXP, %d));", target, i - 1L, d$rank), sprintf_safe("int * %s = INTEGER(VECTOR_ELT(%s, %d));", - d$dimnames$length, target, i - 1L), + name, target, i - 1L), sprintf_safe("%s[%d] = %s;", - d$dimnames$length, seq_len(d$rank) - 1L, + name, seq_len(d$rank) - 1L, vcapply(d$dimnames$dim, rewrite, USE.NAMES = FALSE))) } } diff --git a/R/ir_deserialise.R b/R/ir_deserialise.R index 4cb1f608..1b3c98dc 100644 --- a/R/ir_deserialise.R +++ b/R/ir_deserialise.R @@ -30,10 +30,6 @@ ir_deserialise <- function(ir) { dat$version <- numeric_version(dat$version) dat$components <- lapply(dat$components, lapply, list_to_character) - if (dat$features$has_array) { - dat$data$elements <- lapply(dat$data$elements, ir_deserialise_data_dimnames) - } - names(dat$data$elements) <- vcapply(dat$data$elements, "[[", "name") names(dat$data$variable$contents) <- vcapply(dat$data$variable$contents, "[[", "name") @@ -65,12 +61,3 @@ ir_deserialise_equation <- function(eq) { } eq } - - -ir_deserialise_data_dimnames <- function(x) { - if (x$rank > 0L) { - v <- c("dim", "mult") - x$dimnames[v] <- lapply(x$dimnames[v], list_to_character) - } - x -} diff --git a/R/ir_parse.R b/R/ir_parse.R index df5a1b55..65a227ab 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -19,6 +19,13 @@ ir_parse <- function(x, options, type = NULL) { eqs <- ir_parse_arrays(eqs, variables, config$include$names, source) + ## This performs a round of optimisation where we try to simplify + ## away expressions for the dimensions, which reduces the number of + ## required variables. + if (options$rewrite_dims && features$has_array) { + eqs <- ir_parse_rewrite_dims(eqs) + } + packing <- ir_parse_packing(eqs, variables, source) eqs <- c(eqs, packing$offsets) packing$offsets <- NULL @@ -318,13 +325,14 @@ ir_parse_stage <- function(eqs, dependencies, variables, time_name, source) { stage[names_if(vlapply(eqs, is_null))] <- STAGE_NULL i <- vlapply(eqs, function(x) !is.null(x$array)) - len <- unique(vcapply(eqs[i], function(x) x$array$dimnames$length)) - err <- stage[len] == STAGE_TIME + len <- lapply(eqs[i], function(x) x$array$dimnames$length) + len_var <- list_to_character(len[vlapply(len, is.character)]) + err <- stage[len_var] == STAGE_TIME if (any(err)) { ir_parse_error( "Array extent is determined by time", - ir_parse_error_lines(eqs[len[err]]), source) + ir_parse_error_lines(eqs[len_var[err]]), source) } stage @@ -358,7 +366,8 @@ ir_parse_packing_internal <- function(names, rank, len, variables, } else if (identical(offset[[i]], 0L)) { offset[[i + 1L]] <- as.name(len[[i]]) } else { - offset[[i + 1L]] <- call("+", offset[[i]], as.name(len[[i]])) + len_i <- if (is.numeric(len[[i]])) len[[i]] else as.name(len[[i]]) + offset[[i + 1L]] <- static_eval(call("+", offset[[i]], len_i)) } } @@ -1490,3 +1499,75 @@ ir_parse_expr_rhs_check_inplace <- function(lhs, rhs, line, source) { line, source) } } + + +## This approach could probably be applied over the whole tree really, +## as we might be able to eliminate some other compile time +## things. However, doing that will make the models less debuggable. +ir_parse_rewrite_dims <- function(eqs) { + compute <- function(x) { + if (is.numeric(x)) { + x + } else if (is.symbol(x)) { + x_eq <- eqs[[deparse_str(x)]] + if (x_eq$type == "expression_scalar") { + compute(x_eq$rhs$value) + } else if (x_eq$type == "user") { + x + } else { + stop("CHECK") # I don't think this is possible and return 'x'? + } + } else if (is.recursive(x)) { + x[-1] <- lapply(x[-1], compute) + x + } + } + + ## alternatively look in all $array$dimnames elements + nms <- grep("dim_", names(eqs), value = TRUE) + + val <- lapply(eqs[nms], function(eq) + static_eval(compute(eq$rhs$value))) + + rewrite <- vlapply(val, function(x) is.symbol(x) || is.numeric(x)) + + subs <- val[rewrite] + + ## Try and deduplicate the rest. However, it's not totally obvious + ## that we can do this without creating some weird dependency + ## issues. + leave <- val[!rewrite] + dup <- duplicated(leave) + if (any(dup)) { + i <- match(leave[dup], leave) + subs <- c(subs, + set_names(lapply(names(leave)[i], as.name), names(leave)[dup])) + } + + replace <- function(x, y) { + i <- match(vcapply(x, function(x) x %||% ""), names(y)) + j <- which(!is.na(i)) + x[j] <- unname(y)[i[j]] + na_drop(x) + } + + subs_env <- list2env(subs, parent = emptyenv()) + subs_dep <- vcapply(subs, function(x) + if (is.numeric(x)) NA_character_ else deparse_str(x)) + + rewrite_eq <- function(eq) { + eq$rhs$value <- substitute_(eq$rhs$value, subs_env) + + eq$depends$variables <- replace(eq$depends$variables, subs_dep) + eq$lhs$depends$variables <- replace(eq$lhs$depends$variables, subs_dep) + + if (!is.null(eq$array$dimnames)) { + eq$array$dimnames$length <- replace(eq$array$dimnames$length, subs)[[1]] + eq$array$dimnames$dim <- replace(eq$array$dimnames$dim, subs) + eq$array$dimnames$mult <- replace(eq$array$dimnames$mult, subs) + } + eq + } + + lapply(eqs[setdiff(names(eqs), names(subs))], rewrite_eq) +} diff --git a/R/ir_parse_arrays.R b/R/ir_parse_arrays.R index 7d78dbe6..0a176345 100644 --- a/R/ir_parse_arrays.R +++ b/R/ir_parse_arrays.R @@ -465,13 +465,13 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { depends = depends_dim) } eq_dim <- lapply(seq_len(rank), f_eq_dim) - dimnames$dim <- vcapply(eq_dim, "[[", "name") + dimnames$dim <- lapply(eq_dim, "[[", "name") ## At this point, modify how we compute total length: dims <- lapply(dimnames$dim, as.name) eq_length$rhs$value <- r_fold_call("*", dims) eq_length$depends <- list(functions = character(0), - variables = dimnames$dim) + variables = list_to_character(dimnames$dim)) ## Even more bits if (rank > 2L) { @@ -492,7 +492,8 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { } eq_mult <- lapply(3:rank, f_eq_mult) } - dimnames$mult <- c("", dimnames$dim[[1]], vcapply(eq_mult, "[[", "name")) + dimnames$mult <- c(list("", dimnames$dim[[1]]), + lapply(eq_mult, "[[", "name")) } no_alloc <- diff --git a/R/ir_serialise.R b/R/ir_serialise.R index 67bf75b7..25ff0e96 100644 --- a/R/ir_serialise.R +++ b/R/ir_serialise.R @@ -72,6 +72,8 @@ ir_serialise_data <- function(data) { } else { ret$dimnames <- x$dimnames ret$dimnames$length <- ir_serialise_expression(x$dimnames$length) + ret$dimnames$dim <- lapply(ret$dimnames$dim, ir_serialise_expression) + ret$dimnames$mult <- lapply(ret$dimnames$mult, ir_serialise_expression) } ret$stage <- scalar(STAGE_NAME[x$stage + 1L]) ret diff --git a/R/odin.R b/R/odin.R index ec0f5c58..72e746ad 100644 --- a/R/odin.R +++ b/R/odin.R @@ -70,6 +70,9 @@ ##' messages about unused variables. Defaults to the option ##' `odin.no_check_unused_equations` or `FALSE` otherwise. ##' +##' @param options An [odin_options] object; if given then this +##' overrides all options above. +##' ##' @return A function that can generate the model ##' ##' @author Rich FitzJohn @@ -98,7 +101,8 @@ ##' plot(y, xlab = "Time", ylab = "y", main = "", las = 1) odin <- function(x, verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, - compiler_warnings = NULL, no_check_unused_equations = NULL) { + compiler_warnings = NULL, no_check_unused_equations = NULL, + options = NULL) { xx <- substitute(x) if (is.symbol(xx)) { xx <- force(x) @@ -107,7 +111,7 @@ odin <- function(x, verbose = NULL, target = NULL, workdir = NULL, xx <- force(x) } odin_(xx, verbose, target, workdir, validate, pretty, skip_cache, - compiler_warnings, no_check_unused_equations) + compiler_warnings, no_check_unused_equations, options) } @@ -115,15 +119,18 @@ odin <- function(x, verbose = NULL, target = NULL, workdir = NULL, ##' @rdname odin odin_ <- function(x, verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, - compiler_warnings = NULL, no_check_unused_equations = NULL) { - options <- odin_options(verbose = verbose, - target = target, - workdir = workdir, - validate = validate, - pretty = pretty, - skip_cache = skip_cache, - no_check_unused_equations = no_check_unused_equations, - compiler_warnings = compiler_warnings) + compiler_warnings = NULL, no_check_unused_equations = NULL, + options = NULL) { + options <- odin_options( + verbose = verbose, + target = target, + workdir = workdir, + validate = validate, + pretty = pretty, + skip_cache = skip_cache, + no_check_unused_equations = no_check_unused_equations, + compiler_warnings = compiler_warnings, + options = options) ir <- odin_parse_(x, options) odin_generate(ir, options) diff --git a/R/odin_options.R b/R/odin_options.R index 98cea340..e4eb67ac 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -16,6 +16,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, compiler_warnings = NULL, no_check_unused_equations = NULL, + rewrite_dims = NULL, options = NULL) { default_target <- if (is.null(target) && !can_compile(verbose = FALSE)) "r" else "c" @@ -26,6 +27,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, workdir = tempfile(), pretty = FALSE, skip_cache = FALSE, + rewrite_dims = FALSE, no_check_unused_equations = FALSE, compiler_warnings = FALSE) if (is.null(options)) { @@ -35,6 +37,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, pretty = pretty, workdir = workdir, skip_cache = skip_cache, + rewrite_dims = rewrite_dims, no_check_unused_equations = no_check_unused_equations, compiler_warnings = compiler_warnings) } diff --git a/R/opt.R b/R/opt.R new file mode 100644 index 00000000..a22d74be --- /dev/null +++ b/R/opt.R @@ -0,0 +1,62 @@ +## things not done: +## Can resolve x - y for numeric args +## Can simplify a + b - c by rewriting as a + b + (-c) +## Pointless parens +## Don't cope with unary +/- +## Factorise simple linear combinations in +? + +## Part of the point of this is to assemble expressions into forms +## that an optimising compiler later in the chain can simplify. +static_eval <- function(expr) { + if (!is.recursive(expr)) { + return(expr) + } + + fn <- expr[[1]] + if (is_call(expr, "+") || is_call(expr, "*")) { + expr <- static_eval_assoc(expr) + } else { + expr[-1] <- lapply(expr[-1], static_eval) + } + + if (is_call(expr, "(") && length(expr) == 2L) { + expr <- expr[[2L]] + } + + expr +} + + +static_eval_assoc <- function(expr) { + fn <- as.character(expr[[1]]) + args <- collect_assoc(lapply(expr[-1], static_eval), fn) + + i <- vlapply(args, is.numeric) + if (any(i)) { + args <- c(args[!i], eval(r_fold_call(fn, args[i]), baseenv())) + } + + if (length(args) == 1L) { + return(args[[1L]]) + } + + r_fold_call(fn, order_args(args)) +} + + +collect_assoc <- function(args, fn) { + args <- as.list(args) + i <- vlapply(args, is_call, fn) + if (any(i)) { + args[i] <- lapply(args[i], function(x) collect_assoc(x[-1], fn)) + flatten1(args) + } else { + args + } +} + + +order_args <- function(args) { + i <- viapply(args, function(x) is.language(x) + is.recursive(x)) + args[order(i, decreasing = TRUE)] +} diff --git a/R/utils.R b/R/utils.R index 4690e0f6..5812cfb4 100644 --- a/R/utils.R +++ b/R/utils.R @@ -156,6 +156,11 @@ list_to_character <- function(x) { } +list_to_numeric <- function(x) { + vnapply(x, identity) +} + + sort_list <- function(x) { x[order(names(x))] } @@ -236,3 +241,13 @@ read_lines <- function(path) { clean_package_name <- function(name) { gsub("_", ".", name) } + + +flatten1 <- function(x) { + unlist(x, FALSE, FALSE) +} + + +na_drop <- function(x) { + x[!is.na(x)] +} diff --git a/inst/schema.json b/inst/schema.json index 1debc534..b854aa9a 100644 --- a/inst/schema.json +++ b/inst/schema.json @@ -783,17 +783,17 @@ "type": "object", "properties": { "length": { - "type": "string" + "$ref": "#/definitions/sexpression" }, "dim": { "oneOf": [ - {"$ref": "#/definitions/basic/character_vector"}, + {"$ref": "#/definitions/sexpression_vector"}, {"type": "null"} ] }, "mult": { "oneOf": [ - {"$ref": "#/definitions/basic/character_vector"}, + {"$ref": "#/definitions/sexpression_vector"}, {"type": "null"} ] } diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R new file mode 100644 index 00000000..9d0c090a --- /dev/null +++ b/tests/testthat/test-opt.R @@ -0,0 +1,30 @@ +expect_equal(static_eval(quote(1 + 2)), 3) +expect_equal(static_eval(quote(1 + 2 + 3)), 6) + +expect_equal(static_eval(quote(a + 1 + 2)), quote(a + 3)) +expect_equal(static_eval(quote(1 + a + 2)), quote(a + 3)) +expect_equal(static_eval(quote(1 + 2 + a)), quote(a + 3)) + +expect_equal(static_eval(quote(a + 1 + b + 2 + c + 3)), + quote(a + b + c + 6)) + +expect_equal(collect_assoc(quote(a + b + c), quote(`+`)), + list(quote(a), quote(b), quote(c))) +expect_equal(collect_assoc(quote(a + 1 + b + 2 + c + 3), quote(`+`)), + list(quote(a), 1, quote(b), 2, quote(c), 3)) + +expect_equal(static_eval(quote(1 + (2) + 3)), 6) + +static_eval(quote(1 + (a + 2) + 3)) + +static_eval(quote(1 + (a + 2))) + +expect_equal(static_eval(quote(1 + (a + 2) + 3)), + quote(a + 6)) + +expect_equal(static_eval(quote((a + 2 * 3) + 4 * 5)), + quote(a + 26)) +expect_equal(static_eval(quote((a + 2 * 3) + 4 * b)), + quote(b * 4 + a + 6)) +expect_equal(static_eval(quote((1 + 4) * (b + 3))), + quote((b + 3) * 5)) diff --git a/tests/testthat/test-parse2-general.R b/tests/testthat/test-parse2-general.R index b2d1941c..cffbb7c6 100644 --- a/tests/testthat/test-parse2-general.R +++ b/tests/testthat/test-parse2-general.R @@ -822,3 +822,27 @@ test_that("can't use C identifier", { }), "Reserved name 'int' for lhs") }) + + +test_that("rewrite arrays", { + ## This does break + ## * dependencies.R:23 + ## * ir_pase.R: 769 + ## * ir_parse_arrays (465, 480) + + ## so not too bad. + + ir <- odin_parse({ + n <- 2 + m <- 2 + deriv(S[, ]) <- 0 + deriv(I) <- S[n, m] + dim(S) <- c(n, m) + initial(S[, ]) <- S0[i, j] + initial(I) <- 0 + S0[, ] <- user() + dim(S0) <- c(n, m) + }) + + +}) diff --git a/tests/testthat/test-parse2-rewrite.R b/tests/testthat/test-parse2-rewrite.R index 26b233de..e285f5cc 100644 --- a/tests/testthat/test-parse2-rewrite.R +++ b/tests/testthat/test-parse2-rewrite.R @@ -13,3 +13,67 @@ test_that("log", { output(a) <- log(1, 2, 3) }), "Expected 1-2 arguments in log call", class = "odin_error") }) + + + + +test_that("rewrite arrays", { + options <- odin_options(rewrite_dims = TRUE, validate = TRUE) + ir <- odin_parse({ + n <- 2 + m <- 2 + deriv(S[, ]) <- 0 + deriv(I) <- S[n, m] + dim(S) <- c(n, m) + initial(S[, ]) <- S0[i, j] + initial(I) <- 0 + S0[, ] <- user() + dim(S0) <- c(n, m) + }, options = options) + gen <- odin_generate(ir, options) + mod <- gen(S0 = matrix(runif(4), 2, 2)) + mod$run(0:10) + + options <- odin_options(validate = TRUE) + ir <- odin_parse({ + n <- 2 + m <- 2 + deriv(S[, ]) <- 0 + deriv(I) <- S[n, m] + dim(S) <- c(n, m) + initial(S[, ]) <- S0[i, j] + initial(I) <- 0 + S0[, ] <- user() + dim(S0) <- c(n, m) + }, options = options) + + odin_generate(ir, options) + + ir_deserialise(ir) + + expect_false(grepl("dim_S_1", ir)) + expect_false(grepl("dim_S", ir)) +}) + + +test_that("rewrite arrays with shared dimensions", { + options <- odin_options(rewrite_dims = TRUE, validate = FALSE) + ir <- odin_parse({ + n <- user(integer = TRUE) + m <- user(integer = TRUE) + deriv(x[, ]) <- 0 + deriv(y[, ]) <- 0 + initial(x[, ]) <- 0 + initial(y[, ]) <- 0 + dim(x) <- c(n, m) + dim(y) <- c(n, m) + }, options = options) + gen <- odin_generate(ir, options) + mod <- gen(n = 4, m = 5) + mod$contents() + + static_eval(quote(a * 2 * 3)) + static_eval(quote(2 * 3 * a)) + + +}) diff --git a/tests/testthat/test-run-opt.R b/tests/testthat/test-run-opt.R new file mode 100644 index 00000000..7669cd6f --- /dev/null +++ b/tests/testthat/test-run-opt.R @@ -0,0 +1,19 @@ +context("odin: opt") + +test_that("optimise dimensions away entirely", { + options <- odin_options(rewrite_dims = TRUE) + gen <- odin({ + n <- 2 + m <- 2 + deriv(S[, ]) <- 0 + deriv(I) <- S[n, m] + dim(S) <- c(n, m) + initial(S[, ]) <- S0[i, j] + initial(I) <- 0 + S0[, ] <- user() + dim(S0) <- c(n, m) + }, options = options) + + + +}) From 19ef2aab14c23ccf89f39fc2d4a1d275ca20d70b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 8 Mar 2021 10:24:25 +0000 Subject: [PATCH 02/24] Sort packing to reduce offsets --- R/ir_parse.R | 7 +++++-- tests/testthat/test-parse2-general.R | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index 65a227ab..072886e3 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -350,8 +350,11 @@ ir_parse_packing_new <- function(eqs, variables, offset_prefix) { ir_parse_packing_internal <- function(names, rank, len, variables, offset_prefix) { - ## We'll pack from least to most complex: - i <- order(rank) + ## We'll pack from least to most complex and everything with a fixed + ## offset first. This puts all scalars first, then all arrays that + ## have compile-time size next (in order of rank), then all arrays + ## with user-time size (in order of rank). + i <- order(!vlapply(len, is.numeric), rank) names <- names[i] rank <- rank[i] len <- len[i] diff --git a/tests/testthat/test-parse2-general.R b/tests/testthat/test-parse2-general.R index cffbb7c6..f8921d4a 100644 --- a/tests/testthat/test-parse2-general.R +++ b/tests/testthat/test-parse2-general.R @@ -844,5 +844,4 @@ test_that("rewrite arrays", { dim(S0) <- c(n, m) }) - }) From faba99f30122b77dbe9d8329b33c4911b0a146cb Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 8 Mar 2021 11:46:41 +0000 Subject: [PATCH 03/24] Start ironing corner cases --- R/ir_parse.R | 10 +++++----- tests/testthat/test-ir.R | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index 072886e3..5f448560 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -326,7 +326,7 @@ ir_parse_stage <- function(eqs, dependencies, variables, time_name, source) { i <- vlapply(eqs, function(x) !is.null(x$array)) len <- lapply(eqs[i], function(x) x$array$dimnames$length) - len_var <- list_to_character(len[vlapply(len, is.character)]) + len_var <- vcapply(len[vlapply(len, is.name)], deparse_str) err <- stage[len_var] == STAGE_TIME if (any(err)) { @@ -366,8 +366,6 @@ ir_parse_packing_internal <- function(names, rank, len, variables, for (i in seq_along(names)) { if (!is_array[[i]]) { offset[[i + 1L]] <- i - } else if (identical(offset[[i]], 0L)) { - offset[[i + 1L]] <- as.name(len[[i]]) } else { len_i <- if (is.numeric(len[[i]])) len[[i]] else as.name(len[[i]]) offset[[i + 1L]] <- static_eval(call("+", offset[[i]], len_i)) @@ -1513,9 +1511,11 @@ ir_parse_rewrite_dims <- function(eqs) { x } else if (is.symbol(x)) { x_eq <- eqs[[deparse_str(x)]] - if (x_eq$type == "expression_scalar") { + ## use identical() here to cope with x_eq being NULL when 't' is + ## passed through (that will be an error elsewhere). + if (identical(x_eq$type, "expression_scalar")) { compute(x_eq$rhs$value) - } else if (x_eq$type == "user") { + } else if (is.null(x_eq) || x_eq$type == "user") { x } else { stop("CHECK") # I don't think this is possible and return 'x'? diff --git a/tests/testthat/test-ir.R b/tests/testthat/test-ir.R index bbc1655d..294635f1 100644 --- a/tests/testthat/test-ir.R +++ b/tests/testthat/test-ir.R @@ -14,6 +14,6 @@ test_that("deserialise", { test_that("Stage information included in IR", { ir <- odin_parse_("examples/array_odin.R") dat <- odin_ir_deserialise(ir) - expect_equal(dat$data$elements$dim_S$stage, "constant") + expect_equal(dat$data$elements$N_age$stage, "constant") expect_equal(dat$data$elements$I_tot$stage, "time") }) From 309563394b3eaea7e6691da3fbaa83a1ec049d60 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 8 Mar 2021 17:15:37 +0000 Subject: [PATCH 04/24] Working through tests --- R/ir_parse.R | 4 +++- R/ir_parse_arrays.R | 2 +- tests/testthat/test-run-basic.R | 12 ++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index 5f448560..859aa000 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -1515,7 +1515,9 @@ ir_parse_rewrite_dims <- function(eqs) { ## passed through (that will be an error elsewhere). if (identical(x_eq$type, "expression_scalar")) { compute(x_eq$rhs$value) - } else if (is.null(x_eq) || x_eq$type == "user") { + } else if (is.null(x_eq) || x_eq$type %in% c("user", "null")) { + ## TODO: we get 'null' here from interpolated variables that + ## are problematic. x } else { stop("CHECK") # I don't think this is possible and return 'x'? diff --git a/R/ir_parse_arrays.R b/R/ir_parse_arrays.R index 0a176345..42e3f1e9 100644 --- a/R/ir_parse_arrays.R +++ b/R/ir_parse_arrays.R @@ -488,7 +488,7 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { implicit = TRUE, source = eq$source, depends = list(functions = character(0), - variables = dimnames$dim[j])) + variables = list_to_character(dimnames$dim[j]))) } eq_mult <- lapply(3:rank, f_eq_mult) } diff --git a/tests/testthat/test-run-basic.R b/tests/testthat/test-run-basic.R index 6a9e3bd8..87a8286b 100644 --- a/tests/testthat/test-run-basic.R +++ b/tests/testthat/test-run-basic.R @@ -266,7 +266,7 @@ test_that_odin("array support", { n <- 3 dim(r) <- n dim(x) <- n - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen() @@ -310,7 +310,7 @@ test_that_odin("3d array", { initial(y[, , ]) <- 1 deriv(y[, , ]) <- y[i, j, k] * 0.1 dim(y) <- c(2, 3, 4) - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen() d <- mod$contents() @@ -388,7 +388,7 @@ test_that_odin("user array - indirect", { dim(r) <- n dim(x) <- n n <- user() - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen(n = 3, r = 1:3) expect_equal(sort_list(mod$contents()), @@ -432,7 +432,7 @@ test_that_odin("user array - direct 3d", { deriv(y) <- 1 r[, , ] <- user() dim(r) <- user() - }) + }, options = odin_options(rewrite_dims = FALSE)) m <- array(runif(24), 2:4) mod <- gen(r = m) @@ -459,7 +459,7 @@ test_that_odin("interpolation", { dim(tp) <- user() dim(zp) <- user() output(p) <- pulse - }) + }, options = odin_options(rewrite_dims = FALSE)) tt <- seq(0, 3, length.out = 301) tp <- c(0, 1, 2) @@ -578,7 +578,7 @@ test_that_odin("3d array time dependent and variable", { dim(y) <- c(2, 3, 4) dim(r) <- c(2, 3, 4) r[, , ] <- t * 0.1 - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen() d <- mod$contents() From 35204c15240b79594b9ad355a1c3e447ee546590 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 8 Mar 2021 18:04:12 +0000 Subject: [PATCH 05/24] More tests passing --- R/ir_parse.R | 19 ++++++++++++++++--- tests/testthat/test-run-basic.R | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index 859aa000..b60ac2d0 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -1232,8 +1232,11 @@ ir_parse_delay_discrete <- function(eq, eqs, source) { nm <- eq$name nm_ring <- sprintf("delay_ring_%s", nm) - depends_ring <- list(functions = character(0), - variables = eq$array$dimnames$length %||% character(0)) + len <- eq$array$dimnames$length + depends_ring <- list( + functions = character(0), + variables = if (is.character(len)) len else character(0)) + lhs_ring <- list(name_data = nm_ring, name_equation = nm_ring, name_lhs = nm_ring, storage_type = "ring_buffer") eq_ring <- list( @@ -1522,9 +1525,18 @@ ir_parse_rewrite_dims <- function(eqs) { } else { stop("CHECK") # I don't think this is possible and return 'x'? } + } else if (is_call(x, "length")) { + ## NOTE: use array_dim_name because we might hit things like + ## length(y) where 'y' is one of the variables; we can't look up + ## eqs[[name]]$array$length without checking that. + compute(as.name(array_dim_name(as.character(x[[2]])))) + } else if (is_call(x, "dim")) { + compute(as.name(array_dim_name(as.character(x[[2]]), x[[3]]))) } else if (is.recursive(x)) { x[-1] <- lapply(x[-1], compute) x + } else { # NULL + x } } @@ -1542,7 +1554,8 @@ ir_parse_rewrite_dims <- function(eqs) { ## that we can do this without creating some weird dependency ## issues. leave <- val[!rewrite] - dup <- duplicated(leave) + ## Do not deduplicate NULL dimensions; these are set by user() later. + dup <- duplicated(leave) & !vlapply(leave, is.null) if (any(dup)) { i <- match(leave[dup], leave) subs <- c(subs, diff --git a/tests/testthat/test-run-basic.R b/tests/testthat/test-run-basic.R index 87a8286b..324f032d 100644 --- a/tests/testthat/test-run-basic.R +++ b/tests/testthat/test-run-basic.R @@ -653,7 +653,7 @@ test_that_odin("discrete delays: matrix", { dim(y) <- c(2, 3) dim(z) <- c(2, 3) dim(a) <- c(2, 3) - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen() tt <- 0:10 From 851d00b40a5002b01e48159d929a005cbfea70a5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 08:18:22 +0000 Subject: [PATCH 06/24] Add tests for rewriting rules --- tests/testthat/test-opt.R | 61 +++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index 9d0c090a..4087420d 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -1,30 +1,47 @@ -expect_equal(static_eval(quote(1 + 2)), 3) -expect_equal(static_eval(quote(1 + 2 + 3)), 6) +context("opt") -expect_equal(static_eval(quote(a + 1 + 2)), quote(a + 3)) -expect_equal(static_eval(quote(1 + a + 2)), quote(a + 3)) -expect_equal(static_eval(quote(1 + 2 + a)), quote(a + 3)) +test_that("static_eval completely evaluates numeric expressions", { + expect_equal(static_eval(quote(1 + 2)), 3) + expect_equal(static_eval(quote(1 + 2 + 3)), 6) + expect_equal(static_eval(quote(1 + 2 * 3)), 7) + expect_equal(static_eval(quote(1 + (2) + 3)), 6) + expect_equal(static_eval(quote((1 + 2) * 3)), 9) +}) -expect_equal(static_eval(quote(a + 1 + b + 2 + c + 3)), - quote(a + b + c + 6)) -expect_equal(collect_assoc(quote(a + b + c), quote(`+`)), - list(quote(a), quote(b), quote(c))) -expect_equal(collect_assoc(quote(a + 1 + b + 2 + c + 3), quote(`+`)), - list(quote(a), 1, quote(b), 2, quote(c), 3)) +test_that("static_eval collects numbers up associatively", { + expect_equal(static_eval(quote(a + 3 + 2)), quote(a + 5)) + expect_equal(static_eval(quote(3 + a + 2)), quote(a + 5)) + expect_equal(static_eval(quote(3 + 2 + a)), quote(a + 5)) -expect_equal(static_eval(quote(1 + (2) + 3)), 6) + expect_equal(static_eval(quote(a * 3 * 2)), quote(a * 6)) + expect_equal(static_eval(quote(3 * a * 2)), quote(a * 6)) + expect_equal(static_eval(quote(3 * 2 * a)), quote(a * 6)) -static_eval(quote(1 + (a + 2) + 3)) + expect_equal(static_eval(quote(a + 1 + b + 2 + c + 3)), + quote(a + b + c + 6)) +}) -static_eval(quote(1 + (a + 2))) -expect_equal(static_eval(quote(1 + (a + 2) + 3)), - quote(a + 6)) +test_that("collect_assoc unfolds expressions", { + expect_equal(collect_assoc(quote(a + b + c), quote(`+`)), + list(quote(`+`), quote(a), quote(b), quote(c))) + expect_equal(collect_assoc(quote(a + 1 + b + 2 + c + 3), quote(`+`)), + list(quote(`+`), quote(a), 1, quote(b), 2, quote(c), 3)) +}) -expect_equal(static_eval(quote((a + 2 * 3) + 4 * 5)), - quote(a + 26)) -expect_equal(static_eval(quote((a + 2 * 3) + 4 * b)), - quote(b * 4 + a + 6)) -expect_equal(static_eval(quote((1 + 4) * (b + 3))), - quote((b + 3) * 5)) + +test_that("static_eval removes superfluous parens", { + expect_equal(static_eval(quote(1 + (a + 2))), quote(a + 3)) + expect_equal(static_eval(quote(1 + (a + 2) + 3)), quote(a + 6)) +}) + + +test_that("More complex examples", { + expect_equal(static_eval(quote((a + 2 * 3) + 4 * 5)), + quote(a + 26)) + expect_equal(static_eval(quote((a + 2 * 3) + 4 * b)), + quote(b * 4 + a + 6)) + expect_equal(static_eval(quote((1 + 4) * (b + 3))), + quote((b + 3) * 5)) +}) From 7de2b8c85d25f1a9fc2154f34c9700cb887c3fb9 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 08:18:38 +0000 Subject: [PATCH 07/24] Read init files in tests --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 328ffbac..5ea14d12 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//') -RSCRIPT = Rscript --no-init-file +RSCRIPT = Rscript all: install From f64f64f3f35cb3589e8c046fec67f4cdee448135 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 08:22:22 +0000 Subject: [PATCH 08/24] Corner cases in delays and lengths --- R/ir_parse.R | 46 +++++++++++++++------- tests/testthat/test-parse2-general.R | 23 ----------- tests/testthat/test-run-basic.R | 2 +- tests/testthat/test-run-delay-continuous.R | 21 +++++++--- 4 files changed, 48 insertions(+), 44 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index b60ac2d0..fcb26888 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -1287,24 +1287,32 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { substitutions <- lapply(arrays, function(x) list(from = x, to = sprintf("delay_array_%s", x), - dim = eqs[[x]]$array$dimnames$length)) + dim = as.character(eqs[[x]]$array$dimnames$length))) } else { substitutions <- list() } - eq_len <- list( - name = nm_dim, - type = "expression_scalar", - source = eq$source, - depends = find_symbols(graph$packing$length), - lhs = list(name_data = nm_dim, name_equation = nm_dim, name_lhs = nm_dim, - storage_type = "int"), - rhs = list(value = graph$packing$length)) + if (is.numeric(graph$packing$length)) { + eq_len <- NULL + val_len <- graph$packing$length + dep_len <- character(0) + } else { + eq_len <- list( + name = nm_dim, + type = "expression_scalar", + source = eq$source, + depends = find_symbols(graph$packing$length), + lhs = list(name_data = nm_dim, name_equation = nm_dim, name_lhs = nm_dim, + storage_type = "int"), + rhs = list(value = graph$packing$length)) + val_len <- nm_dim + dep_len <- nm_dim + } lhs_use <- eq$lhs[c("name_data", "name_equation", "name_lhs", "special")] subs_from <- vcapply(substitutions, "[[", "to") depends_use <- join_deps(list( - eq$depends, ir_parse_depends(variables = c(nm_dim, subs_from, TIME)))) + eq$depends, ir_parse_depends(variables = c(dep_len, subs_from, TIME)))) eq_use <- list( name = nm, @@ -1317,7 +1325,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { state = nm_state, index = nm_index, substitutions = substitutions, - variables = list(length = eq_len$name, + variables = list(length = val_len, contents = graph$packing$contents), equations = graph$equations, default = eq$delay$default, @@ -1325,7 +1333,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { depends = eq$delay$depends), array = eq$array) - array <- list(dimnames = list(length = nm_dim, dim = NULL, mult = NULL), + array <- list(dimnames = list(length = val_len, dim = NULL, mult = NULL), rank = 1L) lhs_index <- list(name_data = nm_index, name_equation = nm_index, name_lhs = nm_index, @@ -1335,7 +1343,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { offsets <- lapply(variables$contents[match(graph$variables, variable_names)], "[[", "offset") depends_index <- join_deps(lapply(offsets, find_symbols)) - depends_index$variables <- union(depends_index$variables, nm_dim) + depends_index$variables <- union(depends_index$variables, dep_len) eq_index <- list( name = nm_index, type = "delay_index", @@ -1352,7 +1360,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { name = nm_state, type = "null", source = eq$source, - depends = ir_parse_depends(variables = nm_dim), + depends = ir_parse_depends(variables = dep_len), lhs = lhs_state, array = array) @@ -1361,7 +1369,9 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { eq_index$depends$variables <- c(eq_index$depends$variables, names(offsets)) } - extra <- c(list(eq_len, eq_index, eq_state, eq_use), offsets) + extra <- c(if (is.null(eq_len)) NULL else list(eq_len), + list(eq_index, eq_state, eq_use), + offsets) names(extra) <- vcapply(extra, "[[", "name") stopifnot(sum(names(eqs) == eq$name) == 1) @@ -1584,6 +1594,12 @@ ir_parse_rewrite_dims <- function(eqs) { eq$array$dimnames$dim <- replace(eq$array$dimnames$dim, subs) eq$array$dimnames$mult <- replace(eq$array$dimnames$mult, subs) } + + if (!is.null(eq$delay)) { + eq$delay$depends$variables <- + replace(eq$delay$depends$variables, subs_dep) + } + eq } diff --git a/tests/testthat/test-parse2-general.R b/tests/testthat/test-parse2-general.R index f8921d4a..b2d1941c 100644 --- a/tests/testthat/test-parse2-general.R +++ b/tests/testthat/test-parse2-general.R @@ -822,26 +822,3 @@ test_that("can't use C identifier", { }), "Reserved name 'int' for lhs") }) - - -test_that("rewrite arrays", { - ## This does break - ## * dependencies.R:23 - ## * ir_pase.R: 769 - ## * ir_parse_arrays (465, 480) - - ## so not too bad. - - ir <- odin_parse({ - n <- 2 - m <- 2 - deriv(S[, ]) <- 0 - deriv(I) <- S[n, m] - dim(S) <- c(n, m) - initial(S[, ]) <- S0[i, j] - initial(I) <- 0 - S0[, ] <- user() - dim(S0) <- c(n, m) - }) - -}) diff --git a/tests/testthat/test-run-basic.R b/tests/testthat/test-run-basic.R index 324f032d..1c5f1998 100644 --- a/tests/testthat/test-run-basic.R +++ b/tests/testthat/test-run-basic.R @@ -411,7 +411,7 @@ test_that_odin("user array - direct", { r[] <- user() dim(r) <- user() dim(x) <- length(r) - }) + }, options = odin_options(rewrite_dims = FALSE)) mod <- gen(r = 1:3) expect_equal( diff --git a/tests/testthat/test-run-delay-continuous.R b/tests/testthat/test-run-delay-continuous.R index dfd85251..014379e2 100644 --- a/tests/testthat/test-run-delay-continuous.R +++ b/tests/testthat/test-run-delay-continuous.R @@ -219,16 +219,27 @@ test_that_odin("delay index packing", { dim(e) <- 14 }) + dim_a <- 10 + dim_b <- 11 + dim_c <- 12 + dim_d <- 13 + dim_e <- 14 + dim_foo <- 9 + offset_variable_c <- 21 # i.e., 10 + 11 + offset_variable_e <- 46 # i.e., 10 + 11 + 12 + 13 + mod <- gen() dat <- mod$contents() seq0 <- function(n) seq_len(n) - expect_equal(dat$dim_delay_foo, dat$dim_b + dat$dim_c + dat$dim_e) + if (odin_target_name() == "c") { + expect_length(dat$delay_state_foo, dim_b + dim_c + dim_e) + } - delay_index_foo <- c(dat$dim_a + seq0(dat$dim_b), - dat$offset_variable_c + seq0(dat$dim_c), - dat$offset_variable_e + seq0(dat$dim_e)) + delay_index_foo <- c(dim_a + seq0(dim_b), + offset_variable_c + seq0(dim_c), + offset_variable_e + seq0(dim_e)) if (odin_target_name() == "c") { delay_index_foo <- delay_index_foo - 1L } @@ -237,7 +248,7 @@ test_that_odin("delay index packing", { tt <- seq(0, 10, length.out = 11) yy <- mod$transform_variables(mod$run(tt)) - i <- seq_len(dat$dim_foo) + i <- seq_len(dim_foo) expect_equal(yy$foo[1, ], yy$b[1, i] + yy$c[1, i + 1] + yy$e[1, i + 2]) expect_equal(yy$foo[8, ], From 1e9b7ea694de57f29cfb449011fb1ced54d0457f Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 09:09:51 +0000 Subject: [PATCH 09/24] Remaining corner cases --- R/ir_parse.R | 5 +++-- tests/testthat/test-run-examples.R | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index fcb26888..5d758af3 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -1200,10 +1200,11 @@ ir_parse_delay <- function(eqs, discrete, variables, source) { ## TODO: ideally we'd get the correct lines here for source, but ## that's low down the list of needs. f <- function(x) { + depends <- if (is.numeric(x$dim)) character(0) else as.character(x$dim) list(name = x$to, type = "alloc", source = integer(0), - depends = ir_parse_depends(variables = x$dim), + depends = ir_parse_depends(variables = depends), lhs = list(name_data = x$to, name_lhs = x$to, name_equation = x$to), @@ -1287,7 +1288,7 @@ ir_parse_delay_continuous <- function(eq, eqs, variables, source) { substitutions <- lapply(arrays, function(x) list(from = x, to = sprintf("delay_array_%s", x), - dim = as.character(eqs[[x]]$array$dimnames$length))) + dim = eqs[[x]]$array$dimnames$length)) } else { substitutions <- list() } diff --git a/tests/testthat/test-run-examples.R b/tests/testthat/test-run-examples.R index dedcc094..7baa9a21 100644 --- a/tests/testthat/test-run-examples.R +++ b/tests/testthat/test-run-examples.R @@ -119,7 +119,8 @@ test_that_odin("user arrays", { dat3 <- mod3$contents() dat1 <- mod1$contents() - expect_true(setequal(names(dat1), names(dat3))) + expect_true(all(names(dat1) %in% names(dat3))) + expect_true(all(grepl("^offset_", setdiff(names(dat3), names(dat1))))) expect_equal(dat3[names(dat1)], dat1) ## Now, let's set some different parameters here and check enforcement: @@ -142,7 +143,8 @@ test_that_odin("user arrays", { mod4 <- gen4(age_width = age_width) dat4 <- mod4$contents() - expect_true(setequal(names(dat1), names(dat4))) + expect_true(all(names(dat1) %in% names(dat4))) + expect_true(all(grepl("^(dim|offset)_", setdiff(names(dat4), names(dat1))))) expect_equal(dat4[names(dat1)], dat1) res4 <- mod4$run(t) From 9e07e91f63cded444ef48a99ff7dd48a809926a1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 09:48:18 +0000 Subject: [PATCH 10/24] More corner cases in array-is-time detectection --- R/ir_parse.R | 6 +++- tests/testthat/helper-odin.R | 5 +-- tests/testthat/test-parse2-general.R | 48 ++++++++++++++++------------ 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index 5d758af3..ceabf044 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -326,10 +326,14 @@ ir_parse_stage <- function(eqs, dependencies, variables, time_name, source) { i <- vlapply(eqs, function(x) !is.null(x$array)) len <- lapply(eqs[i], function(x) x$array$dimnames$length) - len_var <- vcapply(len[vlapply(len, is.name)], deparse_str) + ## We end up with sometimes a string and sometimes a symbol here + ## which is unsatisfactory. + len_var <- vcapply(len[!vlapply(len, is.numeric)], as.character) err <- stage[len_var] == STAGE_TIME if (any(err)) { + ## TODO: in the case where we rewrite dimensions this error is not + ## great beause we've lost the dim() call! ir_parse_error( "Array extent is determined by time", ir_parse_error_lines(eqs[len_var[err]]), source) diff --git a/tests/testthat/helper-odin.R b/tests/testthat/helper-odin.R index 7f695ab8..e868b7ed 100644 --- a/tests/testthat/helper-odin.R +++ b/tests/testthat/helper-odin.R @@ -128,8 +128,9 @@ test_that_odin <- function(desc, code) { targets <- test_odin_targets() code_enq <- rlang::enquo(code) for (target in targets) { + opts <- list(odin.target = target, + odin.rewrite_dims = target == "c") testthat::test_that(sprintf("%s (%s)", desc, target), - withr::with_options(list(odin.target = target), - rlang::eval_tidy(code_enq))) + withr::with_options(opts, rlang::eval_tidy(code_enq))) } } diff --git a/tests/testthat/test-parse2-general.R b/tests/testthat/test-parse2-general.R index b2d1941c..8389b3f2 100644 --- a/tests/testthat/test-parse2-general.R +++ b/tests/testthat/test-parse2-general.R @@ -370,26 +370,34 @@ test_that("recursive variables", { }) test_that("array extent and time", { - expect_error(odin_parse_(quote({ - deriv(y[]) <- 1 - initial(y[]) <- 0 - dim(y) <- t - })), "Array extent is determined by time", class = "odin_error") - - expect_error(odin_parse_(quote({ - deriv(y[]) <- 1 - initial(y[]) <- 0 - a <- t - dim(y) <- a - })), "Array extent is determined by time", class = "odin_error") - - expect_error(odin_parse_(quote({ - deriv(y[]) <- 1 - initial(y[]) <- 0 - deriv(z) <- 1 - initial(z) <- 0 - dim(y) <- z - })), "Array extent is determined by time", class = "odin_error") + for (rewrite_dims in c(FALSE, TRUE)) { + expect_error( + odin_parse_(quote({ + deriv(y[]) <- 1 + initial(y[]) <- 0 + dim(y) <- t + }), options = odin_options(rewrite_dims = rewrite_dims)), + "Array extent is determined by time", class = "odin_error") + + expect_error( + odin_parse_(quote({ + deriv(y[]) <- 1 + initial(y[]) <- 0 + a <- t + dim(y) <- a + }), options = odin_options(rewrite_dims = rewrite_dims)), + "Array extent is determined by time", class = "odin_error") + + expect_error( + odin_parse_(quote({ + deriv(y[]) <- 1 + initial(y[]) <- 0 + deriv(z) <- 1 + initial(z) <- 0 + dim(y) <- z + }), options = odin_options(rewrite_dims = rewrite_dims)), + "Array extent is determined by time", class = "odin_error") + } }) test_that("lhs checking", { From 0c2fb21f19189497572c88ebe33bfa1395bd83d1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 09:50:55 +0000 Subject: [PATCH 11/24] Tidy up test --- tests/testthat/test-parse2-rewrite.R | 51 ++-------------------------- tests/testthat/test-run-opt.R | 19 ----------- 2 files changed, 2 insertions(+), 68 deletions(-) delete mode 100644 tests/testthat/test-run-opt.R diff --git a/tests/testthat/test-parse2-rewrite.R b/tests/testthat/test-parse2-rewrite.R index e285f5cc..5125a259 100644 --- a/tests/testthat/test-parse2-rewrite.R +++ b/tests/testthat/test-parse2-rewrite.R @@ -15,10 +15,7 @@ test_that("log", { }) - - -test_that("rewrite arrays", { - options <- odin_options(rewrite_dims = TRUE, validate = TRUE) +test_that("rewrite arrays drops references to dim_ variables", { ir <- odin_parse({ n <- 2 m <- 2 @@ -29,51 +26,7 @@ test_that("rewrite arrays", { initial(I) <- 0 S0[, ] <- user() dim(S0) <- c(n, m) - }, options = options) - gen <- odin_generate(ir, options) - mod <- gen(S0 = matrix(runif(4), 2, 2)) - mod$run(0:10) - - options <- odin_options(validate = TRUE) - ir <- odin_parse({ - n <- 2 - m <- 2 - deriv(S[, ]) <- 0 - deriv(I) <- S[n, m] - dim(S) <- c(n, m) - initial(S[, ]) <- S0[i, j] - initial(I) <- 0 - S0[, ] <- user() - dim(S0) <- c(n, m) - }, options = options) - - odin_generate(ir, options) - - ir_deserialise(ir) - + }, options = odin_options(rewrite_dims = TRUE)) expect_false(grepl("dim_S_1", ir)) expect_false(grepl("dim_S", ir)) }) - - -test_that("rewrite arrays with shared dimensions", { - options <- odin_options(rewrite_dims = TRUE, validate = FALSE) - ir <- odin_parse({ - n <- user(integer = TRUE) - m <- user(integer = TRUE) - deriv(x[, ]) <- 0 - deriv(y[, ]) <- 0 - initial(x[, ]) <- 0 - initial(y[, ]) <- 0 - dim(x) <- c(n, m) - dim(y) <- c(n, m) - }, options = options) - gen <- odin_generate(ir, options) - mod <- gen(n = 4, m = 5) - mod$contents() - - static_eval(quote(a * 2 * 3)) - static_eval(quote(2 * 3 * a)) - - -}) diff --git a/tests/testthat/test-run-opt.R b/tests/testthat/test-run-opt.R deleted file mode 100644 index 7669cd6f..00000000 --- a/tests/testthat/test-run-opt.R +++ /dev/null @@ -1,19 +0,0 @@ -context("odin: opt") - -test_that("optimise dimensions away entirely", { - options <- odin_options(rewrite_dims = TRUE) - gen <- odin({ - n <- 2 - m <- 2 - deriv(S[, ]) <- 0 - deriv(I) <- S[n, m] - dim(S) <- c(n, m) - initial(S[, ]) <- S0[i, j] - initial(I) <- 0 - S0[, ] <- user() - dim(S0) <- c(n, m) - }, options = options) - - - -}) From fbcc70d153dc9807477efbc7d2c2a45eb0f5fb3e Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 10:04:21 +0000 Subject: [PATCH 12/24] Update docs --- R/odin.R | 4 ++-- R/odin_options.R | 5 +---- man/odin.Rd | 8 ++++++-- man/odin_options.Rd | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/R/odin.R b/R/odin.R index 72e746ad..625fce91 100644 --- a/R/odin.R +++ b/R/odin.R @@ -70,8 +70,8 @@ ##' messages about unused variables. Defaults to the option ##' `odin.no_check_unused_equations` or `FALSE` otherwise. ##' -##' @param options An [odin_options] object; if given then this -##' overrides all options above. +##' @param options Named list of options. If provided, then all other +##' options are ignored. ##' ##' @return A function that can generate the model ##' diff --git a/R/odin_options.R b/R/odin_options.R index e4eb67ac..91e71f89 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -1,14 +1,11 @@ ##' For lower-level odin functions [odin::odin_parse], -##' [odin::odin_validate] we accept a list of options rather +##' [odin::odin_validate] we only accept a list of options rather ##' than individually named options. ##' ##' @title Odin options ##' ##' @inheritParams odin ##' -##' @param options Named list of options. If provided, then all other -##' options are ignored. -##' ##' @export ##' @examples ##' odin_options() diff --git a/man/odin.Rd b/man/odin.Rd index d09ff6f2..09b22cc1 100644 --- a/man/odin.Rd +++ b/man/odin.Rd @@ -7,11 +7,12 @@ \usage{ odin(x, verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, compiler_warnings = NULL, - no_check_unused_equations = NULL) + no_check_unused_equations = NULL, options = NULL) odin_(x, verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, - compiler_warnings = NULL, no_check_unused_equations = NULL) + compiler_warnings = NULL, no_check_unused_equations = NULL, + options = NULL) } \arguments{ \item{x}{Either the name of a file to read, a text string (if @@ -52,6 +53,9 @@ version depending on underlying support in pkgbuild.} \item{no_check_unused_equations}{If \code{TRUE}, then don't print messages about unused variables. Defaults to the option \code{odin.no_check_unused_equations} or \code{FALSE} otherwise.} + +\item{options}{Named list of options. If provided, then all other +options are ignored.} } \value{ A function that can generate the model diff --git a/man/odin_options.Rd b/man/odin_options.Rd index 561dab05..e5ffe99a 100644 --- a/man/odin_options.Rd +++ b/man/odin_options.Rd @@ -7,7 +7,7 @@ odin_options(verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, compiler_warnings = NULL, no_check_unused_equations = NULL, - options = NULL) + rewrite_dims = NULL, options = NULL) } \arguments{ \item{verbose}{Logical scalar indicating if the compilation should @@ -50,7 +50,7 @@ options are ignored.} } \description{ For lower-level odin functions \link{odin_parse}, -\link{odin_validate} we accept a list of options rather +\link{odin_validate} we only accept a list of options rather than individually named options. } \examples{ From 6b17802649f8b3ee453d80ec397b44b2522c0d07 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 10:08:31 +0000 Subject: [PATCH 13/24] Bump version and add news --- DESCRIPTION | 2 +- NEWS.md | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 8499b8a0..bebd1e3f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin Title: ODE Generation and Integration -Version: 1.1.8 +Version: 1.1.9 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Thibaut", "Jombart", role = "ctb"), diff --git a/NEWS.md b/NEWS.md index 66d7fbc5..07d58541 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# odin 1.1.9 + +* New option `rewrite_dims` (via `odin::odin_options`) which will attempt to simplify common dimensions. This can reduce the number of variables carried around in the model as these are typically very redundant and also known at compile time (mrc-2093) + # odin 1.1.8 * Annotate equations with `# ignore.unused` to locally suppress messages about unused variables (mrc-2122) From d971556fac24c3e290f3586ad6439be594410dfa Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 10:35:59 +0000 Subject: [PATCH 14/24] Add missing docs --- R/odin_options.R | 8 ++++++++ man/odin_options.Rd | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/R/odin_options.R b/R/odin_options.R index 91e71f89..e1b9d362 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -6,6 +6,14 @@ ##' ##' @inheritParams odin ##' +##' @param rewrite_dims Logical, indicating if odin should try and +##' rewrite your model dimensions (if using arrays). If `TRUE` then +##' we replace dimensions known at compile-time with literal +##' integers, and those known at initialisation with simplified and +##' shared expressions. You may get less-comprehensible error +##' messages with this option set to `TRUE` because parts of the +##' model have been effectively evaluated during processing. +##' ##' @export ##' @examples ##' odin_options() diff --git a/man/odin_options.Rd b/man/odin_options.Rd index e5ffe99a..176bb827 100644 --- a/man/odin_options.Rd +++ b/man/odin_options.Rd @@ -45,6 +45,14 @@ version depending on underlying support in pkgbuild.} messages about unused variables. Defaults to the option \code{odin.no_check_unused_equations} or \code{FALSE} otherwise.} +\item{rewrite_dims}{Logical, indicating if odin should try and +rewrite your model dimensions (if using arrays). If \code{TRUE} then +we replace dimensions known at compile-time with literal +integers, and those known at initialisation with simplified and +shared expressions. You may get less-comprehensible error +messages with this option set to \code{TRUE} because parts of the +model have been effectively evaluated during processing.} + \item{options}{Named list of options. If provided, then all other options are ignored.} } From 530faa443caabaf72c51cdd4f08f474ba81a2c03 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 10:40:29 +0000 Subject: [PATCH 15/24] Check spelling --- NEWS.md | 2 +- inst/WORDLIST | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 07d58541..b8cb7f3e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,7 +23,7 @@ * The basic infrastructure has been overhauled, which will make some alternative compilation targets easier to support. We now use `pkgbuild` for the compilation which should ease debugging, and odin code compiled into packages will no longer issue a slew of warnings (and cooperate with automatic routine registration). This refactor has caused a few minor breaking changes: - `$initial()` always requires time, even if it is ignored - `$set_user()` and construction no longer work with positional argument matching - all arguments must be named - - The `$ir` field has become a method; add parens after it + - The `$ir` field has become a method; add parentheses after it - The `compiler_warnings` option has been removed # odin 1.0.7 diff --git a/inst/WORDLIST b/inst/WORDLIST index f65e4846..70f89ad7 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -1,4 +1,5 @@ AppVeyor +CodeFactor Deserialise Dormand FitzJohn @@ -40,7 +41,10 @@ json knitr lorenz mathcal +mrc odin's +pkgbuild +proc rightarrow rmarkdown roxygen From d61d4a70dc7ddd3c95f6c5db23ba4378d06379c6 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 10:43:33 +0000 Subject: [PATCH 16/24] Remove unused branches --- R/ir_parse.R | 6 +----- R/utils.R | 5 ----- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/R/ir_parse.R b/R/ir_parse.R index ceabf044..8a623dfb 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -1533,12 +1533,8 @@ ir_parse_rewrite_dims <- function(eqs) { ## passed through (that will be an error elsewhere). if (identical(x_eq$type, "expression_scalar")) { compute(x_eq$rhs$value) - } else if (is.null(x_eq) || x_eq$type %in% c("user", "null")) { - ## TODO: we get 'null' here from interpolated variables that - ## are problematic. - x } else { - stop("CHECK") # I don't think this is possible and return 'x'? + x } } else if (is_call(x, "length")) { ## NOTE: use array_dim_name because we might hit things like diff --git a/R/utils.R b/R/utils.R index 5812cfb4..20468917 100644 --- a/R/utils.R +++ b/R/utils.R @@ -156,11 +156,6 @@ list_to_character <- function(x) { } -list_to_numeric <- function(x) { - vnapply(x, identity) -} - - sort_list <- function(x) { x[order(names(x))] } From d4b398c857f1d284d6f9e23f1300bd5415cb5d68 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 11:28:36 +0000 Subject: [PATCH 17/24] Add class attribute --- R/odin_options.R | 3 +++ man/odin_options.Rd | 3 +++ tests/testthat/test-odin-options.R | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/R/odin_options.R b/R/odin_options.R index e1b9d362..eacced14 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -14,6 +14,8 @@ ##' messages with this option set to `TRUE` because parts of the ##' model have been effectively evaluated during processing. ##' +##' @return A list of parameters, of class `odin_options` +##' ##' @export ##' @examples ##' odin_options() @@ -62,5 +64,6 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, read_include_unsupported(options$target)) } + class(options) <- "odin_options" options } diff --git a/man/odin_options.Rd b/man/odin_options.Rd index 176bb827..e25ddbdb 100644 --- a/man/odin_options.Rd +++ b/man/odin_options.Rd @@ -56,6 +56,9 @@ model have been effectively evaluated during processing.} \item{options}{Named list of options. If provided, then all other options are ignored.} } +\value{ +A list of parameters, of class \code{odin_options} +} \description{ For lower-level odin functions \link{odin_parse}, \link{odin_validate} we only accept a list of options rather diff --git a/tests/testthat/test-odin-options.R b/tests/testthat/test-odin-options.R index f866ea9e..72f7e037 100644 --- a/tests/testthat/test-odin-options.R +++ b/tests/testthat/test-odin-options.R @@ -1,5 +1,11 @@ context("odin_options") +test_that("odin_options creates a classed list", { + opts <- odin_options() + expect_s3_class(opts, "odin_options") + expect_true(is.list(opts)) +}) + test_that("can create placeholder handler for include parsing", { opts <- odin_options(target = "fortran") expect_error( From 89fec40dbae1aec28804f6f4729c80334ac66b54 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 9 Mar 2021 14:37:54 +0000 Subject: [PATCH 18/24] Minimal validation of inputs --- R/odin_options.R | 20 +++++++++++--------- R/utils.R | 20 ++++++++++++++++++++ tests/testthat/test-util.R | 17 +++++++++++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/R/odin_options.R b/R/odin_options.R index eacced14..89a55e54 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -38,15 +38,17 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, no_check_unused_equations = FALSE, compiler_warnings = FALSE) if (is.null(options)) { - options <- list(validate = validate, - verbose = verbose, - target = target, - pretty = pretty, - workdir = workdir, - skip_cache = skip_cache, - rewrite_dims = rewrite_dims, - no_check_unused_equations = no_check_unused_equations, - compiler_warnings = compiler_warnings) + options <- list( + validate = assert_scalar_logical_or_null(validate), + verbose = assert_scalar_logical_or_null(verbose), + target = target, + pretty = assert_scalar_logical_or_null(pretty), + workdir = workdir, + skip_cache = assert_scalar_logical_or_null(skip_cache), + rewrite_dims = assert_scalar_logical_or_null(rewrite_dims), + no_check_unused_equations = + assert_scalar_logical_or_null(no_check_unused_equations), + compiler_warnings = assert_scalar_logical_or_null(compiler_warnings)) } stopifnot(all(names(defaults) %in% names(options))) diff --git a/R/utils.R b/R/utils.R index 20468917..092757e9 100644 --- a/R/utils.R +++ b/R/utils.R @@ -246,3 +246,23 @@ flatten1 <- function(x) { na_drop <- function(x) { x[!is.na(x)] } + + +assert_scalar_logical_or_null <- function(x, name = deparse(substitute(x))) { + if (!is.null(x)) { + if (length(x) != 1 || !is.logical(x) || !is.na(x)) { + stop(sprintf("Expected '%s' to be a logical scalar (or NULL)", name)) + } + } + invisible(x) +} + + +assert_scalar_logical_or_null <- function(x, name = deparse(substitute(x))) { + if (!is.null(x)) { + if (length(x) != 1 || !is.logical(x) || is.na(x)) { + stop(sprintf("Expected '%s' to be a logical scalar (or NULL)", name)) + } + } + invisible(x) +} diff --git a/tests/testthat/test-util.R b/tests/testthat/test-util.R index 7d0f2ddf..4d082167 100644 --- a/tests/testthat/test-util.R +++ b/tests/testthat/test-util.R @@ -161,3 +161,20 @@ test_that("Don't set envvar if not needed", { mockery::mock_args(mock_compile_dll)[[1]], list(path, compile_attributes, quiet)) }) + + +test_that("validate inputs", { + expect_silent(assert_scalar_logical_or_null(NULL)) + expect_silent(assert_scalar_logical_or_null(TRUE)) + expect_silent(assert_scalar_logical_or_null(FALSE)) + + thing <- "true" + expect_error( + assert_scalar_logical_or_null(thing), + "Expected 'thing' to be a logical scalar (or NULL)", + fixed = TRUE) + expect_error(assert_scalar_logical_or_null(NA), + "Expected '.+' to be a logical scalar \\(or NULL\\)") + expect_error(assert_scalar_logical_or_null(logical(0)), + "Expected '.+' to be a logical scalar \\(or NULL\\)") +}) From 8092d2ff802439527cc86fd6469488b9a64f5946 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 11 Mar 2021 11:54:12 +0000 Subject: [PATCH 19/24] Fix ordering with multiple symbols --- R/opt.R | 2 +- tests/testthat/test-opt.R | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/R/opt.R b/R/opt.R index a22d74be..e3f8c9d5 100644 --- a/R/opt.R +++ b/R/opt.R @@ -58,5 +58,5 @@ collect_assoc <- function(args, fn) { order_args <- function(args) { i <- viapply(args, function(x) is.language(x) + is.recursive(x)) - args[order(i, decreasing = TRUE)] + args[order(-i, vcapply(args, deparse_str))] } diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index 4087420d..66253682 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -45,3 +45,16 @@ test_that("More complex examples", { expect_equal(static_eval(quote((1 + 4) * (b + 3))), quote((b + 3) * 5)) }) + + +test_that("sort expressions", { + expect_equal( + static_eval(quote(a + 1 + b + 2)), + quote(a + b + 3)) + expect_equal( + static_eval(quote(1 + b + a + 2)), + quote(a + b + 3)) + expect_equal( + static_eval(quote(1 + b + a + 2 + x * y)), + quote(x * y a + b + 3)) +}) From 2e647331feb153058ec9ddcaf679d944a75951d1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 11 Mar 2021 11:59:20 +0000 Subject: [PATCH 20/24] Fix syntax --- tests/testthat/test-opt.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index 66253682..dce82215 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -56,5 +56,5 @@ test_that("sort expressions", { quote(a + b + 3)) expect_equal( static_eval(quote(1 + b + a + 2 + x * y)), - quote(x * y a + b + 3)) + quote(x * y + a + b + 3)) }) From 0df2965afe8d4a02c40bd48a798f64681d620805 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Mar 2021 19:03:33 +0000 Subject: [PATCH 21/24] Treat more corner cases --- R/opt.R | 11 +++++++++++ tests/testthat/test-opt.R | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/R/opt.R b/R/opt.R index e3f8c9d5..11f3d86b 100644 --- a/R/opt.R +++ b/R/opt.R @@ -36,6 +36,17 @@ static_eval_assoc <- function(expr) { args <- c(args[!i], eval(r_fold_call(fn, args[i]), baseenv())) } + if (fn == "+") { + args <- args[!vlapply(args, function(x) is.numeric(x) && x == 0)] + } + + if (fn == "*") { + if (any(vlapply(args, function(x) is.numeric(x) && x == 0))) { + args <- list(0) + } + args <- args[!vlapply(args, function(x) is.numeric(x) && x == 1)] + } + if (length(args) == 1L) { return(args[[1L]]) } diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index dce82215..c3fb0c0c 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -58,3 +58,21 @@ test_that("sort expressions", { static_eval(quote(1 + b + a + 2 + x * y)), quote(x * y + a + b + 3)) }) + + +test_that("Addition of zero is a noop", { + expect_equal(static_eval(quote(a + 0)), quote(a)) + expect_equal(static_eval(quote(a + 0 + b)), quote(a + b)) +}) + + +test_that("Multiplication by one is a noop", { + expect_equal(static_eval(quote(a * 1)), quote(a)) + expect_equal(static_eval(quote(a * 1 * b)), quote(a * b)) +}) + + +test_that("Multiplication by one is a noop", { + expect_equal(static_eval(quote(a * 1)), quote(a)) + expect_equal(static_eval(quote(a * 1 * b)), quote(a * b)) +}) From fa1ca538debcc9bb5df1b730ff03a87fcf480ac5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 15 Mar 2021 13:40:12 +0000 Subject: [PATCH 22/24] Cope with large sums --- R/opt.R | 46 ++++++++++++++++++++++++++++----------- tests/testthat/test-opt.R | 34 +++++++++++++++++++---------- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/R/opt.R b/R/opt.R index 11f3d86b..34988edd 100644 --- a/R/opt.R +++ b/R/opt.R @@ -28,8 +28,14 @@ static_eval <- function(expr) { static_eval_assoc <- function(expr) { + expr <- flatten_assoc(expr) + expr[-1] <- lapply(expr[-1], static_eval) + + ## We need a *second* round here of flatten_assoc + expr <- flatten_assoc(expr) + fn <- as.character(expr[[1]]) - args <- collect_assoc(lapply(expr[-1], static_eval), fn) + args <- expr[-1L] i <- vlapply(args, is.numeric) if (any(i)) { @@ -38,6 +44,18 @@ static_eval_assoc <- function(expr) { if (fn == "+") { args <- args[!vlapply(args, function(x) is.numeric(x) && x == 0)] + + ## Collect linear combinations of shared parameters here; this + ## causes issues for simplifying general expressions (e.g., a + 1 + ## * (a + a) will end up as 2 * a + a) but odin doesn't generate + ## things like that (yet). + i <- match(args, args) + if (anyDuplicated(i)) { + for (k in unique(i[duplicated(i)])) { + args[[k]] <- call("*", args[[k]], as.numeric(sum(i == k))) + } + args <- args[!duplicated(i)] + } } if (fn == "*") { @@ -55,19 +73,21 @@ static_eval_assoc <- function(expr) { } -collect_assoc <- function(args, fn) { - args <- as.list(args) - i <- vlapply(args, is_call, fn) - if (any(i)) { - args[i] <- lapply(args[i], function(x) collect_assoc(x[-1], fn)) - flatten1(args) - } else { - args - } -} - - order_args <- function(args) { i <- viapply(args, function(x) is.language(x) + is.recursive(x)) args[order(-i, vcapply(args, deparse_str))] } + + +flatten_assoc <- function(expr) { + fn <- expr[[1L]] + check <- as.list(expr[-1L]) + args <- list() + while (length(check) > 0) { + i <- vlapply(check, is_call, fn) + args <- c(args, check[!i]) + check <- unlist(lapply(check[i], function(x) as.list(x[-1])), FALSE) + } + + c(list(fn), args) +} diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index c3fb0c0c..1b4d70ee 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -23,14 +23,6 @@ test_that("static_eval collects numbers up associatively", { }) -test_that("collect_assoc unfolds expressions", { - expect_equal(collect_assoc(quote(a + b + c), quote(`+`)), - list(quote(`+`), quote(a), quote(b), quote(c))) - expect_equal(collect_assoc(quote(a + 1 + b + 2 + c + 3), quote(`+`)), - list(quote(`+`), quote(a), 1, quote(b), 2, quote(c), 3)) -}) - - test_that("static_eval removes superfluous parens", { expect_equal(static_eval(quote(1 + (a + 2))), quote(a + 3)) expect_equal(static_eval(quote(1 + (a + 2) + 3)), quote(a + 6)) @@ -72,7 +64,27 @@ test_that("Multiplication by one is a noop", { }) -test_that("Multiplication by one is a noop", { - expect_equal(static_eval(quote(a * 1)), quote(a)) - expect_equal(static_eval(quote(a * 1 * b)), quote(a * b)) +test_that("Multiplication by zero is catatrophic", { + expect_equal(static_eval(quote(a * 0)), 0) + expect_equal(static_eval(quote(a * 0 * b)), 0) +}) + + +test_that("Can evaluate very long expressions", { + v <- sprintf("x%d", seq_len(200)) + e <- parse(text = paste(v, collapse = " + "))[[1]] + expect_equal( + static_eval(e), + r_fold_call("+", lapply(sort(v), as.name))) +}) + + +test_that("Can collect linear combinations", { + expect_equal( + static_eval(quote(a + b + a + b + a + 4)), + quote(a * 3 + b * 2 + 4)) + ## This is something to pick up later + expect_equal( + static_eval(quote(a + 1 * (a + a))), + quote(a * 2 + a)) }) From e152547b6acef711a2151f5ef1d309398b1c0d12 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 17 Mar 2021 16:41:06 +0000 Subject: [PATCH 23/24] Fix corner case with zero addition --- R/opt.R | 3 +++ tests/testthat/test-opt.R | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/R/opt.R b/R/opt.R index 34988edd..ce2c28a3 100644 --- a/R/opt.R +++ b/R/opt.R @@ -44,6 +44,9 @@ static_eval_assoc <- function(expr) { if (fn == "+") { args <- args[!vlapply(args, function(x) is.numeric(x) && x == 0)] + if (length(args) == 0) { + return(0) + } ## Collect linear combinations of shared parameters here; this ## causes issues for simplifying general expressions (e.g., a + 1 diff --git a/tests/testthat/test-opt.R b/tests/testthat/test-opt.R index 1b4d70ee..794c6809 100644 --- a/tests/testthat/test-opt.R +++ b/tests/testthat/test-opt.R @@ -88,3 +88,13 @@ test_that("Can collect linear combinations", { static_eval(quote(a + 1 * (a + a))), quote(a * 2 + a)) }) + + +test_that("cope with adding zeros", { + expect_equal( + static_eval(quote(0 + 0)), + 0) + expect_equal( + static_eval(quote(0 * x + 1 * 0)), + 0) +}) From f2b63a8b390e24057c46a769fa56e2ecb49f08d0 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 17 Mar 2021 16:48:07 +0000 Subject: [PATCH 24/24] Remove unused utility --- R/utils.R | 5 ----- 1 file changed, 5 deletions(-) diff --git a/R/utils.R b/R/utils.R index 092757e9..41dcee4f 100644 --- a/R/utils.R +++ b/R/utils.R @@ -238,11 +238,6 @@ clean_package_name <- function(name) { } -flatten1 <- function(x) { - unlist(x, FALSE, FALSE) -} - - na_drop <- function(x) { x[!is.na(x)] }