Skip to content

Commit

Permalink
Merge pull request #43 from mrc-ide/mrc-5648
Browse files Browse the repository at this point in the history
Support for sums over arrays
  • Loading branch information
richfitz authored Sep 17, 2024
2 parents 53fda66 + 43b83b1 commit d52ead0
Show file tree
Hide file tree
Showing 10 changed files with 312 additions and 35 deletions.
1 change: 1 addition & 0 deletions R/constants.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ FUNCTIONS <- list(
length = 1,
nrow = 1,
ncol = 1,
sum = 1,
ceiling = ceiling,
sign = sign,
floor = floor,
Expand Down
1 change: 0 additions & 1 deletion R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ generate_dust_system <- function(dat) {

body <- collector()
body$add("#include <dust2/common.hpp>")
body$add("#include <dust2/array.hpp>")
body$add(generate_dust_system_attributes(dat))
body$add(sprintf("class %s {", dat$class))
body$add("public:")
Expand Down
36 changes: 36 additions & 0 deletions R/generate_dust_sexp.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
fn <- as.character(expr[[1]])
}

## There's a group here where we don't want to evaluate the
## arguments, because the interpretation of some values will be
## different to odin's normal rewrite semantics.
if (fn == "[") {
return(generate_dust_array_access(expr, dat, options))
} else if (fn == "OdinDim") {
Expand Down Expand Up @@ -41,8 +44,11 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
return(generate_dust_sexp(
call("OdinDim", as.character(expr[[2]]), if (fn == "nrow") 1 else 2),
dat, options))
} else if (fn == "OdinReduce") {
return(generate_dust_sexp_reduce(expr, dat, options))
}

## Below here is much simpler, really.
args <- vcapply(expr[-1], generate_dust_sexp, dat, options)
n <- length(args)

Expand Down Expand Up @@ -152,3 +158,33 @@ flatten_index <- function(idx, name) {
expr_sum(idx)
}
}


generate_dust_sexp_reduce <- function(expr, dat, options) {
fn <- expr[[2]]
target <- expr[[3]]
target_str <- generate_dust_sexp(expr[[3]], dat, options)
index <- expr$index
dim <- paste0(
if (isFALSE(options$shared_exists)) "dim_" else "shared.dim.",
target)
stopifnot(fn == "sum")
if (is.null(index)) {
sprintf("dust2::array::sum<real_type>(%s, %s)", target_str, dim)
} else {
index_str <- paste(vcapply(index, function(el) {
if (el$type == "single") {
from <- expr_minus(el$at, 1)
to <- from
} else {
from <- expr_minus(el$from, 1)
to <- expr_minus(el$to, 1)
}
sprintf("{%s, %s}",
generate_dust_sexp(from, dat, options),
generate_dust_sexp(to, dat, options))
}), collapse = ", ")
sprintf("dust2::array::sum<real_type>(%s, %s, %s)",
target_str, dim, index_str)
}
}
121 changes: 89 additions & 32 deletions R/parse_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -363,40 +363,12 @@ parse_expr_usage <- function(expr, src, call) {
fn <- expr[[1]]
fn_str <- as.character(fn)
ignore <- "["
if (fn_str %in% monty::monty_dsl_distributions()$name) {
if (fn_str == "sum") {
expr <- parse_expr_usage_rewrite_reduce(expr, src, call)
} else if (fn_str %in% monty::monty_dsl_distributions()$name) {
expr <- parse_expr_usage_rewrite_stochastic(expr, src, call)
} else if (fn_str %in% names(FUNCTIONS)) {
usage <- FUNCTIONS[[fn_str]]
if (is.function(usage)) {
res <- match_call(expr, usage)
if (!res$success) {
err <- conditionMessage(res$error)
odin_parse_error("Invalid call to '{fn_str}': {err}",
"E1028", src, call)
}
} else {
n_args <- length(expr) - 1
if (!is.null(names(expr))) {
odin_parse_error(
"Calls to '{fn_str}' may not have any named arguments",
"E1029", src, call)
}
if (length(usage) == 1) {
if (n_args != usage) {
odin_parse_error(
paste("Invalid call to '{fn_str}': incorrect number of arguments",
"(expected {usage} but received {n_args})"),
"E1030", src, call)
}
} else if (n_args < usage[[1]] || n_args > usage[[2]]) {
collapse <- if (diff(usage) == 1) " or " else " to "
usage_str <- paste(usage, collapse = collapse)
odin_parse_error(
paste("Invalid call to '{fn_str}': incorrect number of arguments",
"(expected {usage_str} but received {n_args})"),
"E1030", src, call)
}
}
parse_expr_check_call(expr, src, call)
args <- lapply(expr[-1], parse_expr_usage, src, call)
expr <- as.call(c(list(fn), args))
} else if (!(fn_str %in% ignore)) {
Expand All @@ -409,6 +381,43 @@ parse_expr_usage <- function(expr, src, call) {
}


parse_expr_check_call <- function(expr, usage, src, call) {
fn <- as.character(expr[[1]])
usage <- FUNCTIONS[[fn]]
if (is.function(usage)) {
res <- match_call(expr, usage)
if (!res$success) {
err <- conditionMessage(res$error)
odin_parse_error("Invalid call to '{fn}': {err}",
"E1028", src, call)
}
} else {
n_args <- length(expr) - 1
if (!is.null(names(expr))) {
odin_parse_error(
"Calls to '{fn}' may not have any named arguments",
"E1029", src, call)
}
if (length(usage) == 1) {
if (n_args != usage) {
odin_parse_error(
paste("Invalid call to '{fn}': incorrect number of arguments",
"(expected {usage} but received {n_args})"),
"E1030", src, call)
}
} else if (n_args < usage[[1]] || n_args > usage[[2]]) {
collapse <- if (diff(usage) == 1) " or " else " to "
usage_str <- paste(usage, collapse = collapse)
odin_parse_error(
paste("Invalid call to '{fn}': incorrect number of arguments",
"(expected {usage_str} but received {n_args})"),
"E1030", src, call)
}
}
expr
}


parse_expr_usage_rewrite_stochastic <- function(expr, src, call) {
res <- monty::monty_dsl_parse_distribution(expr)
if (!res$success) {
Expand All @@ -431,6 +440,54 @@ parse_expr_usage_rewrite_stochastic <- function(expr, src, call) {
}


parse_expr_usage_rewrite_reduce <- function(expr, src, call) {
parse_expr_check_call(expr, src, call)

fn <- as.character(expr[[1]])
arg <- expr[[2]]
if (rlang::is_symbol(arg)) {
name <- as.character(arg)
return(call("OdinReduce", fn, name, index = NULL))
} else if (!rlang::is_call(arg, "[")) {
odin_parse_error(
c("Expected argument to '{fn}' to be an array",
i = paste("The argument to '{fn}' should be name of an array (as",
"a symbol) to sum over all elements of the array, or",
"an array access (using '[]') to sum over part of",
"an array")),
"E1033", src, call)
}

name <- as.character(arg[[2]])
index <- as.list(arg[-(1:2)])

## Handle special case efficiently:
if (all(vlapply(index, rlang::is_missing))) {
return(call("OdinReduce", fn, name, index = NULL))
}

for (i in seq_along(index)) {
v <- parse_index(name, i, index[[i]])
deps <- v$depends
if (!is.null(deps)) {
if (":" %in% deps$functions) {
odin_parse_error(
c("Invalid use of range operator ':' within '{fn}' call",
paste("If you use ':' as a range operator within an index,",
"then it must be the outermost call, for e.g,",
"{.code (a + 1):(b + 1)}, not {.code 1 + (a:b)}")),
"E1034", src, call)
}
## And see parse_expr_check_lhs_index for more
}
v$depends <- NULL
index[[i]] <- v
}

call("OdinReduce", fn, name, index = index)
}


rewrite_stochastic_to_expectation <- function(expr) {
if (is.recursive(expr)) {
if (rlang::is_call(expr[[1]], "OdinStochasticCall")) {
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/testthat/helper-odin2.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_pkg_setup <- function(path, name = "pkg") {
dir.create(file.path(path, "inst/odin"), FALSE, TRUE)
writeLines(c(
paste("Package:", name),
"LinkingTo: cpp11, dust2, monty",
"LinkingTo: cpp11, dust2, monty, odin2",
"Imports: dust2",
"Version: 0.0.1",
"Authors@R: c(person('A', 'Person', role = c('aut', 'cre'),",
Expand Down
44 changes: 43 additions & 1 deletion tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,6 @@ test_that("can use length() on the rhs", {
})



test_that("can use nrow() and ncol() on the rhs", {
dat <- odin_parse({
update(x[, ]) <- x[i, j] + nrow(x) / ncol(x)
Expand All @@ -1269,3 +1268,46 @@ test_that("can use nrow() and ncol() on the rhs", {
" }",
"}"))
})


test_that("can generate complete sums over arrays", {
dat <- odin_parse({
update(x) <- sum(y)
initial(x) <- 0
y[] <- Normal(0, 1)
dim(y) <- 3
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_update(dat),
c(method_args$update,
" for (size_t i = 1; i <= shared.dim.y.size; ++i) {",
" internal.y[i - 1] = monty::random::normal<real_type>(rng_state, 0, 1);",
" }",
" state_next[0] = dust2::array::sum<real_type>(internal.y, shared.dim.y);",
"}"))
})


test_that("can generate partial sums over arrays", {
dat <- odin_parse({
update(x[]) <- sum(y[i, ])
initial(x[]) <- 0
y[, ] <- Normal(0, 1)
dim(y) <- c(3, 4)
dim(x) <- 3
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_update(dat),
c(method_args$update,
" for (size_t i = 1; i <= shared.dim.y.dim[0]; ++i) {",
" for (size_t j = 1; j <= shared.dim.y.dim[1]; ++j) {",
" internal.y[i - 1 + (j - 1) * (shared.dim.y.mult[1])] = monty::random::normal<real_type>(rng_state, 0, 1);",
" }",
" }",
" for (size_t i = 1; i <= shared.dim.x.size; ++i) {",
" state_next[i - 1 + 0] = dust2::array::sum<real_type>(internal.y, shared.dim.y, {i - 1, i - 1}, {0, shared.dim.y.dim[1] - 1});",
" }",
"}"))
})
62 changes: 62 additions & 0 deletions tests/testthat/test-parse-expr-array.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,65 @@ test_that("array equations assigning with brackets", {
"Array expressions must always use '[]' on the lhs",
fixed = TRUE)
})


test_that("can parse call to a whole array", {
res <- parse_expr(quote(a <- sum(x)), NULL, NULL)
expect_equal(res$rhs$expr,
quote(OdinReduce("sum", "x", index = NULL)))
expect_equal(res$rhs$depends,
list(functions = "sum", variables = "x"))
})


test_that("all empty index on sum is a complete sum", {
expr <- quote(OdinReduce("sum", "x", index = NULL))
expect_equal(parse_expr(quote(a <- sum(x[])), NULL, NULL)$rhs$expr, expr)
expect_equal(parse_expr(quote(a <- sum(x[, ])), NULL, NULL)$rhs$expr, expr)
expect_equal(parse_expr(quote(a <- sum(x[, , ])), NULL, NULL)$rhs$expr, expr)
})


test_that("can parse call sum over part of array", {
res <- parse_expr(quote(a[] <- sum(x[, i])), NULL, NULL)
expect_equal(
res$rhs$expr,
call("OdinReduce", "sum", "x", index = list(
list(name = "i", type = "range", from = 1, to = quote(OdinDim("x", 1L))),
list(name = "j", type = "single", at = quote(i)))))
expect_equal(res$rhs$depends,
list(functions = c("sum", "["), variables = "x"))
})


test_that("can parse call sum over part of part of array", {
res <- parse_expr(quote(a[] <- sum(x[a:b, i])), NULL, NULL)
expect_equal(
res$rhs$expr,
call("OdinReduce", "sum", "x", index = list(
list(name = "i", type = "range", from = quote(a), to = quote(b)),
list(name = "j", type = "single", at = quote(i)))))
expect_equal(res$rhs$depends,
list(functions = c("sum", "[", ":"),
variables = c("x", "a", "b")))
})


test_that("can check that sum has the right number of arguments", {
expect_error(
parse_expr(quote(a <- sum(x, 1)), NULL, NULL),
"Invalid call to 'sum': incorrect number of arguments")
})


test_that("argument to quote must be symbol or array access", {
expect_error(
parse_expr(quote(a <- sum(x + y)), NULL, NULL),
"Expected argument to 'sum' to be an array")
})

test_that("require that range access is simple", {
expect_error(
parse_expr(quote(y <- sum(x[a:b + c])), NULL, NULL),
"Invalid use of range operator ':' within 'sum' call")
})
Loading

0 comments on commit d52ead0

Please sign in to comment.