diff --git a/DESCRIPTION b/DESCRIPTION index eab6e717..96038afc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin Title: ODE Generation and Integration -Version: 1.1.6 +Version: 1.1.7 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Thibaut", "Jombart", role = "ctb"), diff --git a/R/ir_parse.R b/R/ir_parse.R index 8009c68a..1cb09e97 100644 --- a/R/ir_parse.R +++ b/R/ir_parse.R @@ -9,7 +9,8 @@ ir_parse <- function(x, options, type = NULL) { source <- dat$source ## Data elements: - config <- ir_parse_config(eqs, base, root, source, options$read_include) + config <- ir_parse_config(eqs, base, root, source, options$read_include, + options$config_custom) features <- ir_parse_features(eqs, config, source) variables <- ir_parse_find_variables(eqs, features$discrete, source) diff --git a/R/ir_parse_arrays.R b/R/ir_parse_arrays.R index 3b66aff5..7d78dbe6 100644 --- a/R/ir_parse_arrays.R +++ b/R/ir_parse_arrays.R @@ -86,6 +86,7 @@ ir_parse_arrays_check_usage <- function(eqs, source) { is_user <- vlapply(eqs, function(x) x$type == "user") is_copy <- vlapply(eqs, function(x) x$type == "copy") is_delay <- vlapply(eqs, function(x) x$type == "delay") + is_config <- vlapply(eqs, function(x) x$type == "config") is_delay_array <- vlapply(eqs, function(x) x$type == "delay" && !is.null(x$lhs$index)) name_data <- vcapply(eqs, function(x) x$lhs$name_data) @@ -126,7 +127,7 @@ ir_parse_arrays_check_usage <- function(eqs, source) { } ## Then, start checking for duplicates: - err <- is_duplicated(names(eqs)) & !is_array & !is_inplace + err <- is_duplicated(names(eqs)) & !is_array & !is_inplace & !is_config if (any(err)) { ir_parse_error( sprintf("Duplicate entries must all be array assignments (%s)", diff --git a/R/ir_parse_config.R b/R/ir_parse_config.R index b482793d..c6b3d25e 100644 --- a/R/ir_parse_config.R +++ b/R/ir_parse_config.R @@ -1,15 +1,21 @@ -ir_parse_config <- function(eqs, base_default, root, source, read_include) { +ir_parse_config <- function(eqs, base_default, root, source, + read_include, custom) { i <- vcapply(eqs, "[[", "type") == "config" - config <- lapply(unname(eqs[i]), ir_parse_config1, source) + config <- lapply(unname(eqs[i]), ir_parse_config1, source, custom) nms <- vcapply(config, function(x) x$lhs$name_data) base <- ir_parse_config_base(config[nms == "base"], base_default, source) include <- ir_parse_config_include(config[nms == "include"], root, source, read_include) + custom <- ir_parse_config_custom(config[nms %in% custom], source) - list(base = base, include = include) + ret <- list(base = base, include = include) + if (length(custom) > 0) { + ret$custom <- custom + } + ret } @@ -71,7 +77,22 @@ ir_parse_config_include <- function(include, root, source, read_include) { } -ir_parse_config1 <- function(eq, source) { +ir_parse_config_custom <- function(x, source) { + if (length(x) == 0) { + return(NULL) + } + + ## Is there any other validation that can really be done? We could + ## require that custom cases conform to particular types or are + ## unique? For now we'll be really leniant since we don't document + ## this as a public interface yet. + name <- vcapply(x, function(el) el$lhs$name_lhs) + value <- lapply(x, function(el) el$rhs$value) + unname(Map(list, name = name, value = value)) +} + + +ir_parse_config1 <- function(eq, source, custom) { target <- eq$lhs$name_data value <- eq$rhs$value @@ -79,19 +100,20 @@ ir_parse_config1 <- function(eq, source) { target, base = "character", include = "character", - ir_parse_error(sprintf("Unknown configuration option: %s", target), - eq$source, source)) + NULL) - if (!is.atomic(value)) { - ir_parse_error("config() rhs must be atomic (not an expression or symbol)", - eq$source, source) - } - - if (storage.mode(value) != expected_type) { - ir_parse_error(sprintf( - "Expected a %s for config(%s) but recieved a %s", - expected_type, target, storage.mode(value)), - eq$source, source) + if (is.null(expected_type)) { + if (!(target %in% custom)) { + ir_parse_error(sprintf("Unknown configuration option: %s", target), + eq$source, source) + } + } else { + if (storage.mode(value) != expected_type) { + ir_parse_error(sprintf( + "Expected a %s for config(%s) but recieved a %s", + expected_type, target, storage.mode(value)), + eq$source, source) + } } eq diff --git a/R/ir_serialise.R b/R/ir_serialise.R index 37b3adbb..67bf75b7 100644 --- a/R/ir_serialise.R +++ b/R/ir_serialise.R @@ -19,7 +19,16 @@ ir_serialise_version <- function(version) { ir_serialise_config <- function(config) { - list(base = scalar(config$base), include = config$include) + custom <- config$custom + if (!is.null(config$custom)) { + for (i in seq_along(custom)) { + custom[[i]]$name <- scalar(custom[[i]]$name) + custom[[i]]$value <- scalar(custom[[i]]$value) + } + } + list(base = scalar(config$base), + include = config$include, + custom = custom) } diff --git a/inst/schema.json b/inst/schema.json index 454ca369..1debc534 100644 --- a/inst/schema.json +++ b/inst/schema.json @@ -151,9 +151,11 @@ {"$ref": "#/definitions/include"}, {"type": "null"} ] + }, + "custom": { } }, - "required": ["base", "include"], + "required": ["base", "include", "custom"], "additionalProperties": false }, diff --git a/tests/testthat/test-parse2-config.R b/tests/testthat/test-parse2-config.R index 0c1f8267..ef0966ce 100644 --- a/tests/testthat/test-parse2-config.R +++ b/tests/testthat/test-parse2-config.R @@ -12,18 +12,12 @@ test_that("config() takes a symbol", { class = "odin_error") }) -test_that("config() rhs is atomic", { - expect_error(odin_parse("config(base) <- 1 + 1;"), - "config() rhs must be atomic (not an expression", - fixed = TRUE, class = "odin_error") -}) - test_that("config(base)", { expect_error(odin_parse("config(base) <- 'foo'; config(base) <- 'foo'"), "Expected a single config(base) option", fixed = TRUE, class = "odin_error") expect_error(odin_parse("config(base) <- foo;"), - "config() rhs must be atomic", + "Expected a character for config(base) but recieved a symbol", fixed = TRUE, class = "odin_error") expect_error( odin_parse("config(base) <- 1;"), @@ -61,3 +55,51 @@ test_that("config(include)", { "Duplicated function 'squarepulse' while reading includes", class = "odin_error") }) + + +test_that("Can include multiple files", { + ir <- odin_parse({ + config(include) <- "user_fns.c" + config(include) <- "identity.c" + initial(x) <- 1 + deriv(x) <- 1 + }) + dat <- ir_deserialise(ir) + expect_length(dat$config$include, 2) + expect_equal( + vcapply(dat$config$include$data, function(x) basename(x$filename[[1]])), + c("user_fns.c", "identity.c")) +}) + + +test_that("extend config", { + options <- odin_options(target = "c") + options$config_custom <- "a" + + ir <- odin_parse({ + config(a) <- 1 + initial(x) <- 1 + deriv(x) <- 1 + }, options = options) + expect_equal(ir_deserialise(ir)$config$custom, + list(list(name = "a", value = 1))) + + ir <- odin_parse({ + config(a) <- 1 + config(a) <- 2 + initial(x) <- 1 + deriv(x) <- 1 + }, options = options) + expect_equal(ir_deserialise(ir)$config$custom, + list(list(name = "a", value = 1), + list(name = "a", value = 2))) + + expect_error( + odin_parse({ + config(a) <- 1 + config(b) <- 2 + initial(x) <- 1 + deriv(x) <- 1 + }, options = options), + "Unknown configuration option: b") +})