Skip to content

Commit

Permalink
Merge pull request #216 from mrc-ide/mrc-2117
Browse files Browse the repository at this point in the history
mrc-2117: Allow extendable configuration
  • Loading branch information
weshinsley authored Jan 19, 2021
2 parents 05e8630 + b0ca210 commit 42b3add
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 28 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.1.6
Version: 1.1.7
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
3 changes: 2 additions & 1 deletion R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ ir_parse <- function(x, options, type = NULL) {
source <- dat$source

## Data elements:
config <- ir_parse_config(eqs, base, root, source, options$read_include)
config <- ir_parse_config(eqs, base, root, source, options$read_include,
options$config_custom)
features <- ir_parse_features(eqs, config, source)

variables <- ir_parse_find_variables(eqs, features$discrete, source)
Expand Down
3 changes: 2 additions & 1 deletion R/ir_parse_arrays.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ ir_parse_arrays_check_usage <- function(eqs, source) {
is_user <- vlapply(eqs, function(x) x$type == "user")
is_copy <- vlapply(eqs, function(x) x$type == "copy")
is_delay <- vlapply(eqs, function(x) x$type == "delay")
is_config <- vlapply(eqs, function(x) x$type == "config")
is_delay_array <- vlapply(eqs, function(x)
x$type == "delay" && !is.null(x$lhs$index))
name_data <- vcapply(eqs, function(x) x$lhs$name_data)
Expand Down Expand Up @@ -126,7 +127,7 @@ ir_parse_arrays_check_usage <- function(eqs, source) {
}

## Then, start checking for duplicates:
err <- is_duplicated(names(eqs)) & !is_array & !is_inplace
err <- is_duplicated(names(eqs)) & !is_array & !is_inplace & !is_config
if (any(err)) {
ir_parse_error(
sprintf("Duplicate entries must all be array assignments (%s)",
Expand Down
54 changes: 38 additions & 16 deletions R/ir_parse_config.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
ir_parse_config <- function(eqs, base_default, root, source, read_include) {
ir_parse_config <- function(eqs, base_default, root, source,
read_include, custom) {
i <- vcapply(eqs, "[[", "type") == "config"

config <- lapply(unname(eqs[i]), ir_parse_config1, source)
config <- lapply(unname(eqs[i]), ir_parse_config1, source, custom)

nms <- vcapply(config, function(x) x$lhs$name_data)

base <- ir_parse_config_base(config[nms == "base"], base_default, source)
include <- ir_parse_config_include(config[nms == "include"], root, source,
read_include)
custom <- ir_parse_config_custom(config[nms %in% custom], source)

list(base = base, include = include)
ret <- list(base = base, include = include)
if (length(custom) > 0) {
ret$custom <- custom
}
ret
}


Expand Down Expand Up @@ -71,27 +77,43 @@ ir_parse_config_include <- function(include, root, source, read_include) {
}


ir_parse_config1 <- function(eq, source) {
ir_parse_config_custom <- function(x, source) {
if (length(x) == 0) {
return(NULL)
}

## Is there any other validation that can really be done? We could
## require that custom cases conform to particular types or are
## unique? For now we'll be really leniant since we don't document
## this as a public interface yet.
name <- vcapply(x, function(el) el$lhs$name_lhs)
value <- lapply(x, function(el) el$rhs$value)
unname(Map(list, name = name, value = value))
}


ir_parse_config1 <- function(eq, source, custom) {
target <- eq$lhs$name_data
value <- eq$rhs$value

expected_type <- switch(
target,
base = "character",
include = "character",
ir_parse_error(sprintf("Unknown configuration option: %s", target),
eq$source, source))
NULL)

if (!is.atomic(value)) {
ir_parse_error("config() rhs must be atomic (not an expression or symbol)",
eq$source, source)
}

if (storage.mode(value) != expected_type) {
ir_parse_error(sprintf(
"Expected a %s for config(%s) but recieved a %s",
expected_type, target, storage.mode(value)),
eq$source, source)
if (is.null(expected_type)) {
if (!(target %in% custom)) {
ir_parse_error(sprintf("Unknown configuration option: %s", target),
eq$source, source)
}
} else {
if (storage.mode(value) != expected_type) {
ir_parse_error(sprintf(
"Expected a %s for config(%s) but recieved a %s",
expected_type, target, storage.mode(value)),
eq$source, source)
}
}

eq
Expand Down
11 changes: 10 additions & 1 deletion R/ir_serialise.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ ir_serialise_version <- function(version) {


ir_serialise_config <- function(config) {
list(base = scalar(config$base), include = config$include)
custom <- config$custom
if (!is.null(config$custom)) {
for (i in seq_along(custom)) {
custom[[i]]$name <- scalar(custom[[i]]$name)
custom[[i]]$value <- scalar(custom[[i]]$value)
}
}
list(base = scalar(config$base),
include = config$include,
custom = custom)
}


Expand Down
4 changes: 3 additions & 1 deletion inst/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,11 @@
{"$ref": "#/definitions/include"},
{"type": "null"}
]
},
"custom": {
}
},
"required": ["base", "include"],
"required": ["base", "include", "custom"],
"additionalProperties": false
},

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/identity.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
double identity(double x) {
return x;
}
56 changes: 49 additions & 7 deletions tests/testthat/test-parse2-config.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,12 @@ test_that("config() takes a symbol", {
class = "odin_error")
})

test_that("config() rhs is atomic", {
expect_error(odin_parse("config(base) <- 1 + 1;"),
"config() rhs must be atomic (not an expression",
fixed = TRUE, class = "odin_error")
})

test_that("config(base)", {
expect_error(odin_parse("config(base) <- 'foo'; config(base) <- 'foo'"),
"Expected a single config(base) option",
fixed = TRUE, class = "odin_error")
expect_error(odin_parse("config(base) <- foo;"),
"config() rhs must be atomic",
"Expected a character for config(base) but recieved a symbol",
fixed = TRUE, class = "odin_error")
expect_error(
odin_parse("config(base) <- 1;"),
Expand Down Expand Up @@ -61,3 +55,51 @@ test_that("config(include)", {
"Duplicated function 'squarepulse' while reading includes",
class = "odin_error")
})


test_that("Can include multiple files", {
ir <- odin_parse({
config(include) <- "user_fns.c"
config(include) <- "identity.c"
initial(x) <- 1
deriv(x) <- 1
})
dat <- ir_deserialise(ir)
expect_length(dat$config$include, 2)
expect_equal(
vcapply(dat$config$include$data, function(x) basename(x$filename[[1]])),
c("user_fns.c", "identity.c"))
})


test_that("extend config", {
options <- odin_options(target = "c")
options$config_custom <- "a"

ir <- odin_parse({
config(a) <- 1
initial(x) <- 1
deriv(x) <- 1
}, options = options)
expect_equal(ir_deserialise(ir)$config$custom,
list(list(name = "a", value = 1)))

ir <- odin_parse({
config(a) <- 1
config(a) <- 2
initial(x) <- 1
deriv(x) <- 1
}, options = options)
expect_equal(ir_deserialise(ir)$config$custom,
list(list(name = "a", value = 1),
list(name = "a", value = 2)))

expect_error(
odin_parse({
config(a) <- 1
config(b) <- 2
initial(x) <- 1
deriv(x) <- 1
}, options = options),
"Unknown configuration option: b")
})

0 comments on commit 42b3add

Please sign in to comment.