Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for sums over arrays #43

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R/constants.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ FUNCTIONS <- list(
length = 1,
nrow = 1,
ncol = 1,
sum = 1,
ceiling = ceiling,
sign = sign,
floor = floor,
Expand Down
1 change: 0 additions & 1 deletion R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ generate_dust_system <- function(dat) {

body <- collector()
body$add("#include <dust2/common.hpp>")
body$add("#include <dust2/array.hpp>")
body$add(generate_dust_system_attributes(dat))
body$add(sprintf("class %s {", dat$class))
body$add("public:")
Expand Down
36 changes: 36 additions & 0 deletions R/generate_dust_sexp.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
fn <- as.character(expr[[1]])
}

## There's a group here where we don't want to evaluate the
## arguments, because the interpretation of some values will be
## different to odin's normal rewrite semantics.
if (fn == "[") {
return(generate_dust_array_access(expr, dat, options))
} else if (fn == "OdinDim") {
Expand Down Expand Up @@ -41,8 +44,11 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
return(generate_dust_sexp(
call("OdinDim", as.character(expr[[2]]), if (fn == "nrow") 1 else 2),
dat, options))
} else if (fn == "OdinReduce") {
return(generate_dust_sexp_reduce(expr, dat, options))
}

## Below here is much simpler, really.
args <- vcapply(expr[-1], generate_dust_sexp, dat, options)
n <- length(args)

Expand Down Expand Up @@ -152,3 +158,33 @@ flatten_index <- function(idx, name) {
expr_sum(idx)
}
}


generate_dust_sexp_reduce <- function(expr, dat, options) {
fn <- expr[[2]]
target <- expr[[3]]
target_str <- generate_dust_sexp(expr[[3]], dat, options)
index <- expr$index
dim <- paste0(
if (isFALSE(options$shared_exists)) "dim_" else "shared.dim.",
target)
stopifnot(fn == "sum")
if (is.null(index)) {
sprintf("dust2::array::sum<real_type>(%s, %s)", target_str, dim)
} else {
index_str <- paste(vcapply(index, function(el) {
if (el$type == "single") {
from <- expr_minus(el$at, 1)
to <- from
} else {
from <- expr_minus(el$from, 1)
to <- expr_minus(el$to, 1)
}
sprintf("{%s, %s}",
generate_dust_sexp(from, dat, options),
generate_dust_sexp(to, dat, options))
}), collapse = ", ")
sprintf("dust2::array::sum<real_type>(%s, %s, %s)",
target_str, dim, index_str)
}
}
121 changes: 89 additions & 32 deletions R/parse_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -363,40 +363,12 @@ parse_expr_usage <- function(expr, src, call) {
fn <- expr[[1]]
fn_str <- as.character(fn)
ignore <- "["
if (fn_str %in% monty::monty_dsl_distributions()$name) {
if (fn_str == "sum") {
expr <- parse_expr_usage_rewrite_reduce(expr, src, call)
} else if (fn_str %in% monty::monty_dsl_distributions()$name) {
expr <- parse_expr_usage_rewrite_stochastic(expr, src, call)
} else if (fn_str %in% names(FUNCTIONS)) {
usage <- FUNCTIONS[[fn_str]]
if (is.function(usage)) {
res <- match_call(expr, usage)
if (!res$success) {
err <- conditionMessage(res$error)
odin_parse_error("Invalid call to '{fn_str}': {err}",
"E1028", src, call)
}
} else {
n_args <- length(expr) - 1
if (!is.null(names(expr))) {
odin_parse_error(
"Calls to '{fn_str}' may not have any named arguments",
"E1029", src, call)
}
if (length(usage) == 1) {
if (n_args != usage) {
odin_parse_error(
paste("Invalid call to '{fn_str}': incorrect number of arguments",
"(expected {usage} but received {n_args})"),
"E1030", src, call)
}
} else if (n_args < usage[[1]] || n_args > usage[[2]]) {
collapse <- if (diff(usage) == 1) " or " else " to "
usage_str <- paste(usage, collapse = collapse)
odin_parse_error(
paste("Invalid call to '{fn_str}': incorrect number of arguments",
"(expected {usage_str} but received {n_args})"),
"E1030", src, call)
}
}
parse_expr_check_call(expr, src, call)
args <- lapply(expr[-1], parse_expr_usage, src, call)
expr <- as.call(c(list(fn), args))
} else if (!(fn_str %in% ignore)) {
Expand All @@ -409,6 +381,43 @@ parse_expr_usage <- function(expr, src, call) {
}


parse_expr_check_call <- function(expr, usage, src, call) {
fn <- as.character(expr[[1]])
usage <- FUNCTIONS[[fn]]
if (is.function(usage)) {
res <- match_call(expr, usage)
if (!res$success) {
err <- conditionMessage(res$error)
odin_parse_error("Invalid call to '{fn}': {err}",
"E1028", src, call)
}
} else {
n_args <- length(expr) - 1
if (!is.null(names(expr))) {
odin_parse_error(
"Calls to '{fn}' may not have any named arguments",
"E1029", src, call)
}
if (length(usage) == 1) {
if (n_args != usage) {
odin_parse_error(
paste("Invalid call to '{fn}': incorrect number of arguments",
"(expected {usage} but received {n_args})"),
"E1030", src, call)
}
} else if (n_args < usage[[1]] || n_args > usage[[2]]) {
collapse <- if (diff(usage) == 1) " or " else " to "
usage_str <- paste(usage, collapse = collapse)
odin_parse_error(
paste("Invalid call to '{fn}': incorrect number of arguments",
"(expected {usage_str} but received {n_args})"),
"E1030", src, call)
}
}
expr
}


parse_expr_usage_rewrite_stochastic <- function(expr, src, call) {
res <- monty::monty_dsl_parse_distribution(expr)
if (!res$success) {
Expand All @@ -431,6 +440,54 @@ parse_expr_usage_rewrite_stochastic <- function(expr, src, call) {
}


parse_expr_usage_rewrite_reduce <- function(expr, src, call) {
parse_expr_check_call(expr, src, call)

fn <- as.character(expr[[1]])
arg <- expr[[2]]
if (rlang::is_symbol(arg)) {
name <- as.character(arg)
return(call("OdinReduce", fn, name, index = NULL))
} else if (!rlang::is_call(arg, "[")) {
odin_parse_error(
c("Expected argument to '{fn}' to be an array",
i = paste("The argument to '{fn}' should be name of an array (as",
"a symbol) to sum over all elements of the array, or",
"an array access (using '[]') to sum over part of",
"an array")),
"E1033", src, call)
}

name <- as.character(arg[[2]])
index <- as.list(arg[-(1:2)])

## Handle special case efficiently:
if (all(vlapply(index, rlang::is_missing))) {
return(call("OdinReduce", fn, name, index = NULL))
}

for (i in seq_along(index)) {
v <- parse_index(name, i, index[[i]])
deps <- v$depends
if (!is.null(deps)) {
if (":" %in% deps$functions) {
odin_parse_error(
c("Invalid use of range operator ':' within '{fn}' call",
paste("If you use ':' as a range operator within an index,",
"then it must be the outermost call, for e.g,",
"{.code (a + 1):(b + 1)}, not {.code 1 + (a:b)}")),
"E1034", src, call)
}
## And see parse_expr_check_lhs_index for more
}
v$depends <- NULL
index[[i]] <- v
}

call("OdinReduce", fn, name, index = index)
}


rewrite_stochastic_to_expectation <- function(expr) {
if (is.recursive(expr)) {
if (rlang::is_call(expr[[1]], "OdinStochasticCall")) {
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/testthat/helper-odin2.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_pkg_setup <- function(path, name = "pkg") {
dir.create(file.path(path, "inst/odin"), FALSE, TRUE)
writeLines(c(
paste("Package:", name),
"LinkingTo: cpp11, dust2, monty",
"LinkingTo: cpp11, dust2, monty, odin2",
"Imports: dust2",
"Version: 0.0.1",
"Authors@R: c(person('A', 'Person', role = c('aut', 'cre'),",
Expand Down
44 changes: 43 additions & 1 deletion tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,6 @@ test_that("can use length() on the rhs", {
})



test_that("can use nrow() and ncol() on the rhs", {
dat <- odin_parse({
update(x[, ]) <- x[i, j] + nrow(x) / ncol(x)
Expand All @@ -1269,3 +1268,46 @@ test_that("can use nrow() and ncol() on the rhs", {
" }",
"}"))
})


test_that("can generate complete sums over arrays", {
dat <- odin_parse({
update(x) <- sum(y)
initial(x) <- 0
y[] <- Normal(0, 1)
dim(y) <- 3
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_update(dat),
c(method_args$update,
" for (size_t i = 1; i <= shared.dim.y.size; ++i) {",
" internal.y[i - 1] = monty::random::normal<real_type>(rng_state, 0, 1);",
" }",
" state_next[0] = dust2::array::sum<real_type>(internal.y, shared.dim.y);",
"}"))
})


test_that("can generate partial sums over arrays", {
dat <- odin_parse({
update(x[]) <- sum(y[i, ])
initial(x[]) <- 0
y[, ] <- Normal(0, 1)
dim(y) <- c(3, 4)
dim(x) <- 3
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_update(dat),
c(method_args$update,
" for (size_t i = 1; i <= shared.dim.y.dim[0]; ++i) {",
" for (size_t j = 1; j <= shared.dim.y.dim[1]; ++j) {",
" internal.y[i - 1 + (j - 1) * (shared.dim.y.mult[1])] = monty::random::normal<real_type>(rng_state, 0, 1);",
" }",
" }",
" for (size_t i = 1; i <= shared.dim.x.size; ++i) {",
" state_next[i - 1 + 0] = dust2::array::sum<real_type>(internal.y, shared.dim.y, {i - 1, i - 1}, {0, shared.dim.y.dim[1] - 1});",
" }",
"}"))
})
62 changes: 62 additions & 0 deletions tests/testthat/test-parse-expr-array.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,65 @@ test_that("array equations assigning with brackets", {
"Array expressions must always use '[]' on the lhs",
fixed = TRUE)
})


test_that("can parse call to a whole array", {
res <- parse_expr(quote(a <- sum(x)), NULL, NULL)
expect_equal(res$rhs$expr,
quote(OdinReduce("sum", "x", index = NULL)))
expect_equal(res$rhs$depends,
list(functions = "sum", variables = "x"))
})


test_that("all empty index on sum is a complete sum", {
expr <- quote(OdinReduce("sum", "x", index = NULL))
expect_equal(parse_expr(quote(a <- sum(x[])), NULL, NULL)$rhs$expr, expr)
expect_equal(parse_expr(quote(a <- sum(x[, ])), NULL, NULL)$rhs$expr, expr)
expect_equal(parse_expr(quote(a <- sum(x[, , ])), NULL, NULL)$rhs$expr, expr)
})


test_that("can parse call sum over part of array", {
res <- parse_expr(quote(a[] <- sum(x[, i])), NULL, NULL)
expect_equal(
res$rhs$expr,
call("OdinReduce", "sum", "x", index = list(
list(name = "i", type = "range", from = 1, to = quote(OdinDim("x", 1L))),
list(name = "j", type = "single", at = quote(i)))))
expect_equal(res$rhs$depends,
list(functions = c("sum", "["), variables = "x"))
})


test_that("can parse call sum over part of part of array", {
res <- parse_expr(quote(a[] <- sum(x[a:b, i])), NULL, NULL)
expect_equal(
res$rhs$expr,
call("OdinReduce", "sum", "x", index = list(
list(name = "i", type = "range", from = quote(a), to = quote(b)),
list(name = "j", type = "single", at = quote(i)))))
expect_equal(res$rhs$depends,
list(functions = c("sum", "[", ":"),
variables = c("x", "a", "b")))
})


test_that("can check that sum has the right number of arguments", {
expect_error(
parse_expr(quote(a <- sum(x, 1)), NULL, NULL),
"Invalid call to 'sum': incorrect number of arguments")
})


test_that("argument to quote must be symbol or array access", {
expect_error(
parse_expr(quote(a <- sum(x + y)), NULL, NULL),
"Expected argument to 'sum' to be an array")
})

test_that("require that range access is simple", {
expect_error(
parse_expr(quote(y <- sum(x[a:b + c])), NULL, NULL),
"Invalid use of range operator ':' within 'sum' call")
})
Loading
Loading