diff --git a/DESCRIPTION b/DESCRIPTION index e8c8f6a1..96faa3de 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin Title: ODE Generation and Integration -Version: 1.1.11 +Version: 1.1.12 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Thibaut", "Jombart", role = "ctb"), diff --git a/NEWS.md b/NEWS.md index f394a5fe..adb3839c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# odin 1.1.12 + +* New option `rewrite_constants` (via `odin::odin_options`) which attempts to rewrite all constants in the model code before generation. This can considerably reduce the number of variable lookups (mrc-2252) + # odin 1.1.11 * New option `substitutions` (via `odin::odin_options`) which can substitute in scalar `user` values at compile time (#220) diff --git a/R/generate_c_compiled.R b/R/generate_c_compiled.R index 7f1e1594..c4781337 100644 --- a/R/generate_c_compiled.R +++ b/R/generate_c_compiled.R @@ -319,7 +319,7 @@ generate_c_compiled_rhs_r <- function(dat, rewrite) { sprintf_safe("const %s %s = %s;", time_type, dat$meta$initial_time, initial_time), c_expr_if(sprintf_safe("ISNA(%s)", dat$meta$initial_time), - sprintf_safe('%s = %s(%s, "dat$meta$time");', + sprintf_safe('%s = %s(%s, "%s");', initial_time, time_access, dat$meta$time, dat$meta$time))) reset_initial_time <- diff --git a/R/ir_parse.R b/R/ir_parse.R index 29b15925..28e5fb3d 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -16,14 +16,15 @@ ir_parse <- function(x, options, type = NULL) { variables <- ir_parse_find_variables(eqs, features$discrete, source) eqs <- lapply(eqs, ir_parse_rewrite_initial, variables) - eqs <- ir_parse_arrays(eqs, variables, config$include$names, source) ## This performs a round of optimisation where we try to simplify ## away expressions for the dimensions, which reduces the number of ## required variables. eqs <- ir_parse_substitute(eqs, options$substitutions) - if (options$rewrite_dims && features$has_array) { + if (options$rewrite_constants) { + eqs <- ir_parse_rewrite_constants(eqs) + } else if (options$rewrite_dims && features$has_array) { eqs <- ir_parse_rewrite_dims(eqs) } @@ -1553,59 +1554,100 @@ ir_parse_substitute <- function(eqs, subs) { } -## This approach could probably be applied over the whole tree really, -## as we might be able to eliminate some other compile time -## things. However, doing that will make the models less debuggable. ir_parse_rewrite_dims <- function(eqs) { - compute <- function(x) { - if (is.numeric(x)) { - x - } else if (is.symbol(x)) { + nms <- names_if(vlapply(eqs, function(x) isTRUE(x$lhs$dim))) + ir_parse_rewrite(nms, eqs) +} + + +ir_parse_rewrite_constants <- function(eqs) { + nms <- names_if(vlapply(eqs, function(x) x$type == "expression_scalar")) + ir_parse_rewrite(nms, eqs) +} + + +ir_parse_rewrite_compute_eqs <- function(nms, eqs) { + cache <- new_empty_env() + lapply(eqs[nms], function(eq) + static_eval(ir_parse_rewrite_compute(eq$rhs$value, eqs, cache))) +} + + +ir_parse_rewrite_compute <- function(x, eqs, cache) { + key <- deparse_str(x) + if (key %in% names(cache)) { + return(cache[[key]]) + } + + if (!is.numeric(x)) { + if (is.symbol(x)) { x_eq <- eqs[[deparse_str(x)]] ## use identical() here to cope with x_eq being NULL when 't' is ## passed through (that will be an error elsewhere). if (identical(x_eq$type, "expression_scalar")) { - compute(x_eq$rhs$value) - } else { - x + x <- ir_parse_rewrite_compute(x_eq$rhs$value, eqs, cache) } } else if (is_call(x, "length")) { ## NOTE: use array_dim_name because we might hit things like ## length(y) where 'y' is one of the variables; we can't look up ## eqs[[name]]$array$length without checking that. - compute(as.name(array_dim_name(as.character(x[[2]])))) + length_name <- as.name(array_dim_name(as.character(x[[2]]))) + x <- ir_parse_rewrite_compute(length_name, eqs, cache) } else if (is_call(x, "dim")) { - compute(as.name(array_dim_name(as.character(x[[2]]), x[[3]]))) + dim_name <- as.name(array_dim_name(as.character(x[[2]]), x[[3]])) + x <- ir_parse_rewrite_compute(dim_name, eqs, cache) } else if (is.recursive(x)) { - x[-1] <- lapply(x[-1], compute) - x - } else { # NULL - x + x[-1] <- lapply(x[-1], ir_parse_rewrite_compute, eqs, cache) } } - ## alternatively look in all $array$dimnames elements - nms <- grep("dim_", names(eqs), value = TRUE) + cache[[key]] <- x + x +} + - val <- lapply(eqs[nms], function(eq) - static_eval(compute(eq$rhs$value))) +ir_parse_rewrite <- function(nms, eqs) { + val <- tryCatch( + ir_parse_rewrite_compute_eqs(nms, eqs), + error = function(e) { + message("Rewrite failure: ", e$message) + list() + }) rewrite <- vlapply(val, function(x) is.symbol(x) || is.numeric(x)) subs <- val[rewrite] + ## One small wrinkle here: don't rewrite things that are the target + ## of a copy as the rewrite is complicated. This affects almost + ## nothing in reality outside the tests? + copy_self <- unlist(lapply(eqs, function(x) + if (x$type == "copy") x$lhs$name_data), FALSE) + subs <- subs[setdiff(names(subs), copy_self)] + + is_dim <- vlapply(eqs, function(x) isTRUE(x$lhs$dim)) + ## Try and deduplicate the rest. However, it's not totally obvious ## that we can do this without creating some weird dependency - ## issues. - leave <- val[!rewrite] - ## Do not deduplicate NULL dimensions; these are set by user() later. - dup <- duplicated(leave) & !vlapply(leave, is.null) - if (any(dup)) { - i <- match(leave[dup], leave) - subs <- c(subs, - set_names(lapply(names(leave)[i], as.name), names(leave)[dup])) + ## issues. Also we need to only treat dimensions (and possibly + ## offsets); we could do any compile-time thing really but we don't + ## know it yet. Propagating other expressions through though can + ## create problems. + check <- val[intersect(names_if(!rewrite), names_if(is_dim))] + + if (length(check) > 0) { + dup <- duplicated(check) & !vlapply(check, is.null) + if (any(dup)) { + i <- match(check[dup], check) + subs <- c(subs, + set_names(lapply(names(check)[i], as.name), names(check)[dup])) + } } + subs_env <- list2env(subs, parent = emptyenv()) + subs_dep <- vcapply(subs, function(x) + if (is.numeric(x)) NA_character_ else deparse_str(x)) + replace <- function(x, y) { i <- match(vcapply(x, function(x) x %||% ""), names(y)) j <- which(!is.na(i)) @@ -1613,12 +1655,22 @@ ir_parse_rewrite_dims <- function(eqs) { na_drop(x) } - subs_env <- list2env(subs, parent = emptyenv()) - subs_dep <- vcapply(subs, function(x) - if (is.numeric(x)) NA_character_ else deparse_str(x)) + rewrite_eq_array_part <- function(el) { + el$value <- substitute_(el$value, subs_env) + for (i in seq_along(el$index)) { + el$index[[i]]$value <- substitute_(el$index[[i]]$value, subs_env) + } + el + } rewrite_eq <- function(eq) { - eq$rhs$value <- substitute_(eq$rhs$value, subs_env) + if (eq$type == "expression_array") { + eq$rhs <- lapply(eq$rhs, rewrite_eq_array_part) + } else if (eq$name %in% names(subs)) { + eq$rhs$value <- subs[[eq$name]] + } else { + eq$rhs$value <- substitute_(eq$rhs$value, subs_env) + } eq$depends$variables <- replace(eq$depends$variables, subs_dep) eq$lhs$depends$variables <- replace(eq$lhs$depends$variables, subs_dep) @@ -1630,6 +1682,7 @@ ir_parse_rewrite_dims <- function(eqs) { } if (!is.null(eq$delay)) { + eq$delay$time <- substitute_(eq$delay$time, subs_env) eq$delay$depends$variables <- replace(eq$delay$depends$variables, subs_dep) } @@ -1637,5 +1690,9 @@ ir_parse_rewrite_dims <- function(eqs) { eq } - lapply(eqs[setdiff(names(eqs), names(subs))], rewrite_eq) + ## Can't drop initial(), deriv(), or update() calls even if they are + ## constants. + keep <- names_if(!vlapply(eqs, function(x) is.null(x$lhs$special))) + i <- setdiff(names(eqs), setdiff(names(subs), keep)) + lapply(eqs[i], rewrite_eq) } diff --git a/R/ir_parse_arrays.R b/R/ir_parse_arrays.R index 42e3f1e9..109cedf8 100644 --- a/R/ir_parse_arrays.R +++ b/R/ir_parse_arrays.R @@ -444,7 +444,8 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { lhs = list(name_lhs = eq$name, name_data = eq$name, name_equation = eq$name, - storage_type = "int"), + storage_type = "int", + dim = TRUE), rhs = eq$rhs, depends = depends_dim, source = eq$source) @@ -458,7 +459,7 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { name = d, type = type, lhs = list(name_lhs = d, name_data = d, name_equation = d, - storage_type = "int"), + storage_type = "int", dim = TRUE), rhs = list(value = eq$rhs$value[[i]]), implicit = TRUE, source = eq$source, @@ -483,7 +484,7 @@ ir_parse_arrays_dims <- function(eq, eqs, rank, variables, output) { name = d, type = "expression_scalar", lhs = list(name_lhs = d, name_data = d, name_equation = d, - storage_type = "int"), + storage_type = "int", dim = TRUE), rhs = list(value = r_fold_call("*", dims[j])), implicit = TRUE, source = eq$source, diff --git a/R/odin_options.R b/R/odin_options.R index d495ba95..dd6dc3f0 100644 --- a/R/odin_options.R +++ b/R/odin_options.R @@ -14,11 +14,21 @@ ##' messages with this option set to `TRUE` because parts of the ##' model have been effectively evaluated during processing. ##' -##' @param substitutions Optionally, a list of values to substitute into -##' model specification as constants, even though they are declared -##' as `user()`. This will be most useful in conjunction with -##' `rewrite_dims` to create a copy of your model with dimensions -##' known at compile time and all loops using literal integers. +##' @param rewrite_constants Logical, indicating if odin should try +##' and rewrite *all* constant scalars. This is a superset of +##' `rewrite_dims` and may be slow for large models. Doing this will +##' make your model less debuggable; error messages will reference +##' expressions that have been extensively rewritten, some variables +##' will have been removed entirely or merged with other identical +##' expressions, and the generated code may not be obviously +##' connected to the original code. +##' +##' @param substitutions Optionally, a list of values to substitute +##' into model specification as constants, even though they are +##' declared as `user()`. This will be most useful in conjunction +##' with `rewrite_dims` to create a copy of your model with +##' dimensions known at compile time and all loops using literal +##' integers. ##' ##' @return A list of parameters, of class `odin_options` ##' @@ -29,8 +39,8 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, compiler_warnings = NULL, no_check_unused_equations = NULL, - rewrite_dims = NULL, substitutions = NULL, - options = NULL) { + rewrite_dims = NULL, rewrite_constants = NULL, + substitutions = NULL, options = NULL) { default_target <- if (is.null(target) && !can_compile(verbose = FALSE)) "r" else "c" defaults <- list( @@ -41,6 +51,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, pretty = FALSE, skip_cache = FALSE, rewrite_dims = FALSE, + rewrite_constants = FALSE, substitutions = NULL, no_check_unused_equations = FALSE, compiler_warnings = FALSE) @@ -53,6 +64,7 @@ odin_options <- function(verbose = NULL, target = NULL, workdir = NULL, workdir = workdir, skip_cache = assert_scalar_logical_or_null(skip_cache), rewrite_dims = assert_scalar_logical_or_null(rewrite_dims), + rewrite_constants = assert_scalar_logical_or_null(rewrite_constants), substitutions = check_substitutions(substitutions), no_check_unused_equations = assert_scalar_logical_or_null(no_check_unused_equations), diff --git a/man/odin_options.Rd b/man/odin_options.Rd index 78c945cc..db018bfb 100644 --- a/man/odin_options.Rd +++ b/man/odin_options.Rd @@ -7,7 +7,8 @@ odin_options(verbose = NULL, target = NULL, workdir = NULL, validate = NULL, pretty = NULL, skip_cache = NULL, compiler_warnings = NULL, no_check_unused_equations = NULL, - rewrite_dims = NULL, substitutions = NULL, options = NULL) + rewrite_dims = NULL, rewrite_constants = NULL, substitutions = NULL, + options = NULL) } \arguments{ \item{verbose}{Logical scalar indicating if the compilation should @@ -53,11 +54,21 @@ shared expressions. You may get less-comprehensible error messages with this option set to \code{TRUE} because parts of the model have been effectively evaluated during processing.} -\item{substitutions}{Optionally, a list of values to substitute into -model specification as constants, even though they are declared -as \code{user()}. This will be most useful in conjunction with -\code{rewrite_dims} to create a copy of your model with dimensions -known at compile time and all loops using literal integers.} +\item{rewrite_constants}{Logical, indicating if odin should try +and rewrite \emph{all} constant scalars. This is a superset of +\code{rewrite_dims} and may be slow for large models. Doing this will +make your model less debuggable; error messages will reference +expressions that have been extensively rewritten, some variables +will have been removed entirely or merged with other identical +expressions, and the generated code may not be obviously +connected to the original code.} + +\item{substitutions}{Optionally, a list of values to substitute +into model specification as constants, even though they are +declared as \code{user()}. This will be most useful in conjunction +with \code{rewrite_dims} to create a copy of your model with +dimensions known at compile time and all loops using literal +integers.} \item{options}{Named list of options. If provided, then all other options are ignored.} diff --git a/tests/testthat/helper-odin.R b/tests/testthat/helper-odin.R index e868b7ed..758a990b 100644 --- a/tests/testthat/helper-odin.R +++ b/tests/testthat/helper-odin.R @@ -129,7 +129,7 @@ test_that_odin <- function(desc, code) { code_enq <- rlang::enquo(code) for (target in targets) { opts <- list(odin.target = target, - odin.rewrite_dims = target == "c") + odin.rewrite_constants = target == "c") testthat::test_that(sprintf("%s (%s)", desc, target), withr::with_options(opts, rlang::eval_tidy(code_enq))) } diff --git a/tests/testthat/test-ir.R b/tests/testthat/test-ir.R index 294635f1..183b5693 100644 --- a/tests/testthat/test-ir.R +++ b/tests/testthat/test-ir.R @@ -12,7 +12,8 @@ test_that("deserialise", { test_that("Stage information included in IR", { - ir <- odin_parse_("examples/array_odin.R") + ir <- odin_parse_("examples/array_odin.R", + options = odin_options(rewrite_constants = FALSE)) dat <- odin_ir_deserialise(ir) expect_equal(dat$data$elements$N_age$stage, "constant") expect_equal(dat$data$elements$I_tot$stage, "time") diff --git a/tests/testthat/test-odin-validate.R b/tests/testthat/test-odin-validate.R index efddaa04..0b2e0925 100644 --- a/tests/testthat/test-odin-validate.R +++ b/tests/testthat/test-odin-validate.R @@ -23,7 +23,8 @@ test_that("invalid model", { test_that("unused variables can be detected", { code <- c("initial(x) <- 1", "deriv(x) <- 1", "a <- 1") - res <- odin_validate(code, "text") + res <- odin_validate(code, "text", + odin_options(rewrite_constants = FALSE)) expect_equal(length(res$messages), 1L) expect_match(res$messages[[1]]$msg, "Unused equation: a") expect_equivalent(res$messages[[1]]$line, 3) diff --git a/tests/testthat/test-parse2-general.R b/tests/testthat/test-parse2-general.R index 8389b3f2..4b292913 100644 --- a/tests/testthat/test-parse2-general.R +++ b/tests/testthat/test-parse2-general.R @@ -370,13 +370,18 @@ test_that("recursive variables", { }) test_that("array extent and time", { - for (rewrite_dims in c(FALSE, TRUE)) { + opts <- list( + odin_options(rewrite_dims = FALSE, rewrite_constants = FALSE), + odin_options(rewrite_dims = TRUE, rewrite_constants = FALSE), + odin_options(rewrite_dims = FALSE, rewrite_constants = TRUE)) + + for (o in opts) { expect_error( odin_parse_(quote({ deriv(y[]) <- 1 initial(y[]) <- 0 dim(y) <- t - }), options = odin_options(rewrite_dims = rewrite_dims)), + }), options = o), "Array extent is determined by time", class = "odin_error") expect_error( @@ -385,7 +390,7 @@ test_that("array extent and time", { initial(y[]) <- 0 a <- t dim(y) <- a - }), options = odin_options(rewrite_dims = rewrite_dims)), + }), options = o), "Array extent is determined by time", class = "odin_error") expect_error( @@ -395,7 +400,7 @@ test_that("array extent and time", { deriv(z) <- 1 initial(z) <- 0 dim(y) <- z - }), options = odin_options(rewrite_dims = rewrite_dims)), + }), options = o), "Array extent is determined by time", class = "odin_error") } }) @@ -616,12 +621,17 @@ test_that("check array rhs", { ## Probably more needed here as there are some special cases... test_that("cyclic dependency", { - expect_error( - odin_parse_(ex("x <- y; y <- x")), - "A cyclic dependency detected") - expect_error( - odin_parse_(ex("x <- y; y <- z; z <- x")), - "A cyclic dependency detected") + opts <- list( + odin_options(rewrite_constants = FALSE), + odin_options(rewrite_constants = TRUE)) + for (o in opts) { + expect_error( + odin_parse_(ex("x <- y; y <- x"), options = o), + "A cyclic dependency detected") + expect_error( + odin_parse_(ex("x <- y; y <- z; z <- x"), options = o), + "A cyclic dependency detected") + } }) test_that("range operator on RHS", { @@ -747,7 +757,7 @@ test_that("detect integers", { initial(I) <- 0 S0[, ] <- user() dim(S0) <- c(n, m) - }) + }, options = odin_options(rewrite_constants = FALSE)) dat <- ir_deserialise(ir) type <- vcapply(dat$data$elements, "[[", "storage_type") int <- names_if(type == "int") diff --git a/tests/testthat/test-parse2-rewrite.R b/tests/testthat/test-parse2-rewrite.R index 79b9524a..f70f0fcb 100644 --- a/tests/testthat/test-parse2-rewrite.R +++ b/tests/testthat/test-parse2-rewrite.R @@ -26,7 +26,8 @@ test_that("rewrite arrays drops references to dim_ variables", { initial(I) <- 0 S0[, ] <- user() dim(S0) <- c(n, m) - }, options = odin_options(rewrite_dims = TRUE)) + }, options = odin_options(rewrite_dims = TRUE, + rewrite_constants = FALSE)) expect_false(grepl("dim_S_1", ir)) expect_false(grepl("dim_S", ir)) }) @@ -44,6 +45,7 @@ test_that("Can create compile-time constants", { S0[, ] <- user() dim(S0) <- c(n, m) }, options = odin_options(rewrite_dims = TRUE, + rewrite_constants = FALSE, substitutions = list(n = 2, m = 3))) dat <- ir_deserialise(ir) expect_equal(dat$equations$n$type, "expression_scalar") @@ -102,3 +104,50 @@ test_that("Can validate substitutions", { odin_parse(code, options = odin_options(substitutions = c(n = 1))), "'substitutions' must be a list") }) + + +test_that("Rewrite all constants", { + ir <- odin_parse({ + a <- 10 + b <- 20 + c <- 30 + initial(x) <- 0 + deriv(x) <- a + b * c + }, options = odin_options(rewrite_constants = TRUE)) + dat <- ir_deserialise(ir) + expect_length(dat$equations, 2) + expect_setequal(names(dat$equations), c("initial_x", "deriv_x")) + expect_equal(dat$equations$deriv_x$rhs$value, 610) # i.e., 10 + 20 * 30 +}) + + +test_that("leave time-varying expressions alone", { + ir <- odin_parse({ + a <- 2 * t + deriv(x) <- a * 3 + deriv(y) <- a * 4 + initial(x) <- 0 + initial(y) <- 0 + }, options = odin_options(rewrite_constants = TRUE)) + dat <- ir_deserialise(ir) + expect_equal( + dat$equations$deriv_x$rhs$value, + list("*", "a", 3)) + expect_equal( + dat$equations$deriv_y$rhs$value, + list("*", "a", 4)) +}) + +test_that("collapse complex constants into expressions", { + ir <- odin_parse({ + a <- 2 * t + b <- 2 * n + n <- 4 + deriv(x) <- a + b + initial(x) <- 0 + }, options = odin_options(rewrite_constants = TRUE)) + dat <- ir_deserialise(ir) + expect_equal( + dat$equations$deriv_x$rhs$value, + list("+", "a", 8)) +}) diff --git a/tests/testthat/test-parse2-unused.R b/tests/testthat/test-parse2-unused.R index ec398c5f..da52b49f 100644 --- a/tests/testthat/test-parse2-unused.R +++ b/tests/testthat/test-parse2-unused.R @@ -12,13 +12,15 @@ test_that("one unused variable", { deriv(y) <- 1 initial(y) <- 0 a <- 1 - }), "Unused equation: a") + }, options = odin_options(rewrite_constants = FALSE)), + "Unused equation: a") expect_silent(odin_parse({ deriv(y) <- 1 initial(y) <- 0 a <- 1 - }, options = odin_options(no_check_unused_equations = TRUE))) + }, options = odin_options(rewrite_constants = FALSE, + no_check_unused_equations = TRUE))) }) test_that("more than one unused variable", { @@ -27,7 +29,7 @@ test_that("more than one unused variable", { initial(y) <- 0 a <- 1 b <- 2 - }), + }, options = odin_options(rewrite_constants = FALSE)), "Unused equations: a, b") }) @@ -37,7 +39,7 @@ test_that("dependent unused variables", { initial(y) <- 0 a <- 1 b <- a * 2 - }), + }, options = odin_options(rewrite_constants = FALSE)), "Unused equations: a, b") }) @@ -47,7 +49,7 @@ test_that("dependent non-unused variables", { initial(y) <- 0 a <- 1 b <- a * 2 - })) + }, options = odin_options(rewrite_constants = FALSE))) }) test_that("delayed non-unused variables", { diff --git a/tests/testthat/test-run-basic.R b/tests/testthat/test-run-basic.R index d2abcff7..03f7f9e7 100644 --- a/tests/testthat/test-run-basic.R +++ b/tests/testthat/test-run-basic.R @@ -5,7 +5,7 @@ test_that_odin("trivial model", { deriv(y) <- r initial(y) <- 1 r <- 2 - }) + }, options = odin_options(rewrite_constants = FALSE)) mod <- gen() expect_is(mod, "odin_model") @@ -70,8 +70,7 @@ test_that_odin("Time dependent initial conditions", { expect_equal(mod$deriv(0, 1), f(0)) expect_equal(mod$deriv(1, 1), f(1)) - expect_equal(sort_list(mod$contents()), - sort_list(list(initial_y3 = f(1), r = 1))) + expect_equal(mod$contents()$initial_y3, f(1)) }) @@ -130,10 +129,10 @@ test_that_odin("user variables", { expect_error(gen(r = numeric(0)), "Expected a scalar numeric for 'r'") - expect_equal(sort_list(gen(r = pi)$contents()), - sort_list(list(K = 100, N0 = 1, initial_N = 1, r = pi))) - expect_equal(sort_list(gen(r = pi, N0 = 10)$contents()), - sort_list(list(K = 100, N0 = 10, initial_N = 10, r = pi))) + expect_equal(gen(r = pi)$contents()[c("N0", "r")], + list(N0 = 1, r = pi)) + expect_equal(gen(r = pi, N0 = 10)$contents()[c("N0", "r")], + list(N0 = 10, r = pi)) expect_equal(gen(r = pi, N0 = 10)$initial(0), 10) expect_equal(gen(r = pi, N0 = 10)$deriv(0, 10), pi * 10 * (1 - 10 / 100)) @@ -266,7 +265,7 @@ test_that_odin("array support", { n <- 3 dim(r) <- n dim(x) <- n - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen() @@ -310,7 +309,7 @@ test_that_odin("3d array", { initial(y[, , ]) <- 1 deriv(y[, , ]) <- y[i, j, k] * 0.1 dim(y) <- c(2, 3, 4) - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen() d <- mod$contents() @@ -388,7 +387,7 @@ test_that_odin("user array - indirect", { dim(r) <- n dim(x) <- n n <- user() - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen(n = 3, r = 1:3) expect_equal(sort_list(mod$contents()), @@ -411,7 +410,7 @@ test_that_odin("user array - direct", { r[] <- user() dim(r) <- user() dim(x) <- length(r) - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen(r = 1:3) expect_equal( @@ -432,7 +431,7 @@ test_that_odin("user array - direct 3d", { deriv(y) <- 1 r[, , ] <- user() dim(r) <- user() - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) m <- array(runif(24), 2:4) mod <- gen(r = m) @@ -459,7 +458,7 @@ test_that_odin("interpolation", { dim(tp) <- user() dim(zp) <- user() output(p) <- pulse - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) tt <- seq(0, 3, length.out = 301) tp <- c(0, 1, 2) @@ -578,7 +577,7 @@ test_that_odin("3d array time dependent and variable", { dim(y) <- c(2, 3, 4) dim(r) <- c(2, 3, 4) r[, , ] <- t * 0.1 - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen() d <- mod$contents() @@ -653,7 +652,7 @@ test_that_odin("discrete delays: matrix", { dim(y) <- c(2, 3) dim(z) <- c(2, 3) dim(a) <- c(2, 3) - }, options = odin_options(rewrite_dims = FALSE)) + }, options = odin_options(rewrite_constants = FALSE, rewrite_dims = FALSE)) mod <- gen() tt <- 0:10 @@ -753,6 +752,7 @@ test_that_odin("Can substitute user variables", { S0[, ] <- user() dim(S0) <- c(n, m) }, options = odin_options(rewrite_dims = TRUE, + rewrite_constants = FALSE, substitutions = list(n = 2, m = 3))) expect_equal(nrow(coef(gen)), 1) # only S0 now S0 <- matrix(rpois(6, 10), 2, 3) @@ -762,3 +762,25 @@ test_that_odin("Can substitute user variables", { expect_equal(dat$m, 3) expect_equal(dat$initial_S, S0) }) + + +test_that("Can rewrite common dimensions", { + gen <- odin({ + n <- user(integer = TRUE) + m <- user() + deriv(S[, ]) <- 0 + deriv(I) <- S[n, m] + dim(S) <- c(n, m) + initial(S[, ]) <- S0[i, j] + initial(I) <- 0 + S0[, ] <- user() + dim(S0) <- c(n, m) + }, options = odin_options(rewrite_constants = TRUE)) + + S0 <- matrix(rpois(6, 10), 2, 3) + mod <- gen(S0 = S0, n = 2, m = 3) + dat <- mod$contents() + + expect_equal(sum(c("dim_S0", "dim_S") %in% names(dat)), 1) + expect_equal(dat$initial_S, S0) +}) diff --git a/tests/testthat/test-run-delay-discrete.R b/tests/testthat/test-run-delay-discrete.R index b89db4c2..ddc6b756 100644 --- a/tests/testthat/test-run-delay-discrete.R +++ b/tests/testthat/test-run-delay-discrete.R @@ -31,7 +31,8 @@ test_that_odin("delays: scalar variable", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) + expect_equal(yy$y, drop(cmp)) ## Then check the delayed expression: i <- seq_len(length(tt) - 2) @@ -56,7 +57,8 @@ test_that_odin("delays: scalar expression", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) + expect_equal(yy$y, cmp) ## Then check the delayed expression: i <- seq_len(length(tt) - 2) @@ -81,7 +83,8 @@ test_that_odin("delays: vector variable", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) + expect_equal(yy$y, cmp) ## Then check the delayed expression: i <- seq_len(length(tt) - 2) @@ -107,7 +110,7 @@ test_that_odin("delays: vector expression", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) expect_equal(yy$y, cmp) ## Then check the delayed expression: @@ -165,7 +168,8 @@ test_that_odin("default (scalar)", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) + expect_equal(yy$y, drop(cmp)) ## Then check the delayed expression: i <- seq_len(length(tt) - 2) @@ -196,7 +200,8 @@ test_that_odin("default (vector)", { ## Check that the underlying data are correct: dat <- mod$contents() - cmp <- logistic_map(dat$r, dat$initial_y, diff(range(tt))) + cmp <- logistic_map(3.6, dat$initial_y, diff(range(tt))) + expect_equal(yy$y, cmp) ## Then check the delayed expression: i <- seq_len(length(tt) - 2) diff --git a/tests/testthat/test-run-examples.R b/tests/testthat/test-run-examples.R index 7baa9a21..ee24f226 100644 --- a/tests/testthat/test-run-examples.R +++ b/tests/testthat/test-run-examples.R @@ -120,7 +120,6 @@ test_that_odin("user arrays", { dat3 <- mod3$contents() dat1 <- mod1$contents() expect_true(all(names(dat1) %in% names(dat3))) - expect_true(all(grepl("^offset_", setdiff(names(dat3), names(dat1))))) expect_equal(dat3[names(dat1)], dat1) ## Now, let's set some different parameters here and check enforcement: diff --git a/tests/testthat/test-run-general.R b/tests/testthat/test-run-general.R index 814f64bf..0dcc75d1 100644 --- a/tests/testthat/test-run-general.R +++ b/tests/testthat/test-run-general.R @@ -1349,3 +1349,16 @@ test_that("user c functions can be passed arrays and indexes", { y <- mod$run(c(0, 1)) expect_equal(mod$transform_variables(y[2, ])$y, cumsum(x)) }) + + +test_that_odin("self output for scalar: rewrite corner case", { + gen <- odin({ + initial(a) <- 1 + deriv(a) <- 0 + x <- 2 + 5 + output(x) <- TRUE + }) + + tt <- seq(0, 10, length.out = 11) + expect_equal(gen()$run(tt)[, "x"], rep(7, 11)) +}) diff --git a/tests/testthat/test-run-library.R b/tests/testthat/test-run-library.R index 4eb2a939..70b37b7b 100644 --- a/tests/testthat/test-run-library.R +++ b/tests/testthat/test-run-library.R @@ -59,8 +59,8 @@ test_that_odin("%%", { tt <- seq(-5, 5, length.out = 101) mod <- gen() res <- mod$run(tt) - s <- mod$contents()[["s"]] - q <- mod$contents()[["q"]] + s <- sin(1) + q <- 1.0 expect_equal(res[, "s1"], tt %% s) expect_equal(res[, "s2"], -tt %% s) @@ -93,8 +93,8 @@ test_that_odin("%/%", { tt <- seq(-5, 5, length.out = 101) mod <- gen() res <- mod$run(tt) - s <- mod$contents()[["s"]] - q <- mod$contents()[["q"]] + s <- sin(1) + q <- 1.0 expect_equal(res[, "s1"], tt %/% s) expect_equal(res[, "s2"], -tt %/% s)