diff --git a/R/constants.R b/R/constants.R index 543f113e..6636ba2e 100644 --- a/R/constants.R +++ b/R/constants.R @@ -58,6 +58,7 @@ FUNCTIONS <- list( length = 1, nrow = 1, ncol = 1, + sum = 1, ceiling = ceiling, sign = sign, floor = floor, diff --git a/R/generate_dust.R b/R/generate_dust.R index a1ebddde..9abf5b53 100644 --- a/R/generate_dust.R +++ b/R/generate_dust.R @@ -3,7 +3,6 @@ generate_dust_system <- function(dat) { body <- collector() body$add("#include ") - body$add("#include ") body$add(generate_dust_system_attributes(dat)) body$add(sprintf("class %s {", dat$class)) body$add("public:") diff --git a/R/generate_dust_sexp.R b/R/generate_dust_sexp.R index ea05bc07..aa552352 100644 --- a/R/generate_dust_sexp.R +++ b/R/generate_dust_sexp.R @@ -7,6 +7,9 @@ generate_dust_sexp <- function(expr, dat, options = list()) { fn <- as.character(expr[[1]]) } + ## There's a group here where we don't want to evaluate the + ## arguments, because the interpretation of some values will be + ## different to odin's normal rewrite semantics. if (fn == "[") { return(generate_dust_array_access(expr, dat, options)) } else if (fn == "OdinDim") { @@ -41,8 +44,11 @@ generate_dust_sexp <- function(expr, dat, options = list()) { return(generate_dust_sexp( call("OdinDim", as.character(expr[[2]]), if (fn == "nrow") 1 else 2), dat, options)) + } else if (fn == "OdinReduce") { + return(generate_dust_sexp_reduce(expr, dat, options)) } + ## Below here is much simpler, really. args <- vcapply(expr[-1], generate_dust_sexp, dat, options) n <- length(args) @@ -152,3 +158,33 @@ flatten_index <- function(idx, name) { expr_sum(idx) } } + + +generate_dust_sexp_reduce <- function(expr, dat, options) { + fn <- expr[[2]] + target <- expr[[3]] + target_str <- generate_dust_sexp(expr[[3]], dat, options) + index <- expr$index + dim <- paste0( + if (isFALSE(options$shared_exists)) "dim_" else "shared.dim.", + target) + stopifnot(fn == "sum") + if (is.null(index)) { + sprintf("dust2::array::sum(%s, %s)", target_str, dim) + } else { + index_str <- paste(vcapply(index, function(el) { + if (el$type == "single") { + from <- expr_minus(el$at, 1) + to <- from + } else { + from <- expr_minus(el$from, 1) + to <- expr_minus(el$to, 1) + } + sprintf("{%s, %s}", + generate_dust_sexp(from, dat, options), + generate_dust_sexp(to, dat, options)) + }), collapse = ", ") + sprintf("dust2::array::sum(%s, %s, %s)", + target_str, dim, index_str) + } +} diff --git a/R/parse_expr.R b/R/parse_expr.R index 48598f11..6faf1d28 100644 --- a/R/parse_expr.R +++ b/R/parse_expr.R @@ -363,40 +363,12 @@ parse_expr_usage <- function(expr, src, call) { fn <- expr[[1]] fn_str <- as.character(fn) ignore <- "[" - if (fn_str %in% monty::monty_dsl_distributions()$name) { + if (fn_str == "sum") { + expr <- parse_expr_usage_rewrite_reduce(expr, src, call) + } else if (fn_str %in% monty::monty_dsl_distributions()$name) { expr <- parse_expr_usage_rewrite_stochastic(expr, src, call) } else if (fn_str %in% names(FUNCTIONS)) { - usage <- FUNCTIONS[[fn_str]] - if (is.function(usage)) { - res <- match_call(expr, usage) - if (!res$success) { - err <- conditionMessage(res$error) - odin_parse_error("Invalid call to '{fn_str}': {err}", - "E1028", src, call) - } - } else { - n_args <- length(expr) - 1 - if (!is.null(names(expr))) { - odin_parse_error( - "Calls to '{fn_str}' may not have any named arguments", - "E1029", src, call) - } - if (length(usage) == 1) { - if (n_args != usage) { - odin_parse_error( - paste("Invalid call to '{fn_str}': incorrect number of arguments", - "(expected {usage} but received {n_args})"), - "E1030", src, call) - } - } else if (n_args < usage[[1]] || n_args > usage[[2]]) { - collapse <- if (diff(usage) == 1) " or " else " to " - usage_str <- paste(usage, collapse = collapse) - odin_parse_error( - paste("Invalid call to '{fn_str}': incorrect number of arguments", - "(expected {usage_str} but received {n_args})"), - "E1030", src, call) - } - } + parse_expr_check_call(expr, src, call) args <- lapply(expr[-1], parse_expr_usage, src, call) expr <- as.call(c(list(fn), args)) } else if (!(fn_str %in% ignore)) { @@ -409,6 +381,43 @@ parse_expr_usage <- function(expr, src, call) { } +parse_expr_check_call <- function(expr, usage, src, call) { + fn <- as.character(expr[[1]]) + usage <- FUNCTIONS[[fn]] + if (is.function(usage)) { + res <- match_call(expr, usage) + if (!res$success) { + err <- conditionMessage(res$error) + odin_parse_error("Invalid call to '{fn}': {err}", + "E1028", src, call) + } + } else { + n_args <- length(expr) - 1 + if (!is.null(names(expr))) { + odin_parse_error( + "Calls to '{fn}' may not have any named arguments", + "E1029", src, call) + } + if (length(usage) == 1) { + if (n_args != usage) { + odin_parse_error( + paste("Invalid call to '{fn}': incorrect number of arguments", + "(expected {usage} but received {n_args})"), + "E1030", src, call) + } + } else if (n_args < usage[[1]] || n_args > usage[[2]]) { + collapse <- if (diff(usage) == 1) " or " else " to " + usage_str <- paste(usage, collapse = collapse) + odin_parse_error( + paste("Invalid call to '{fn}': incorrect number of arguments", + "(expected {usage_str} but received {n_args})"), + "E1030", src, call) + } + } + expr +} + + parse_expr_usage_rewrite_stochastic <- function(expr, src, call) { res <- monty::monty_dsl_parse_distribution(expr) if (!res$success) { @@ -431,6 +440,54 @@ parse_expr_usage_rewrite_stochastic <- function(expr, src, call) { } +parse_expr_usage_rewrite_reduce <- function(expr, src, call) { + parse_expr_check_call(expr, src, call) + + fn <- as.character(expr[[1]]) + arg <- expr[[2]] + if (rlang::is_symbol(arg)) { + name <- as.character(arg) + return(call("OdinReduce", fn, name, index = NULL)) + } else if (!rlang::is_call(arg, "[")) { + odin_parse_error( + c("Expected argument to '{fn}' to be an array", + i = paste("The argument to '{fn}' should be name of an array (as", + "a symbol) to sum over all elements of the array, or", + "an array access (using '[]') to sum over part of", + "an array")), + "E1033", src, call) + } + + name <- as.character(arg[[2]]) + index <- as.list(arg[-(1:2)]) + + ## Handle special case efficiently: + if (all(vlapply(index, rlang::is_missing))) { + return(call("OdinReduce", fn, name, index = NULL)) + } + + for (i in seq_along(index)) { + v <- parse_index(name, i, index[[i]]) + deps <- v$depends + if (!is.null(deps)) { + if (":" %in% deps$functions) { + odin_parse_error( + c("Invalid use of range operator ':' within '{fn}' call", + paste("If you use ':' as a range operator within an index,", + "then it must be the outermost call, for e.g,", + "{.code (a + 1):(b + 1)}, not {.code 1 + (a:b)}")), + "E1034", src, call) + } + ## And see parse_expr_check_lhs_index for more + } + v$depends <- NULL + index[[i]] <- v + } + + call("OdinReduce", fn, name, index = index) +} + + rewrite_stochastic_to_expectation <- function(expr) { if (is.recursive(expr)) { if (rlang::is_call(expr[[1]], "OdinStochasticCall")) { diff --git a/R/sysdata.rda b/R/sysdata.rda index c1ac0f15..5ef4bc12 100644 Binary files a/R/sysdata.rda and b/R/sysdata.rda differ diff --git a/tests/testthat/helper-odin2.R b/tests/testthat/helper-odin2.R index 38c2b515..a874b7f1 100644 --- a/tests/testthat/helper-odin2.R +++ b/tests/testthat/helper-odin2.R @@ -31,7 +31,7 @@ test_pkg_setup <- function(path, name = "pkg") { dir.create(file.path(path, "inst/odin"), FALSE, TRUE) writeLines(c( paste("Package:", name), - "LinkingTo: cpp11, dust2, monty", + "LinkingTo: cpp11, dust2, monty, odin2", "Imports: dust2", "Version: 0.0.1", "Authors@R: c(person('A', 'Person', role = c('aut', 'cre'),", diff --git a/tests/testthat/test-generate.R b/tests/testthat/test-generate.R index 8dc7258b..8aa34d5e 100644 --- a/tests/testthat/test-generate.R +++ b/tests/testthat/test-generate.R @@ -1250,7 +1250,6 @@ test_that("can use length() on the rhs", { }) - test_that("can use nrow() and ncol() on the rhs", { dat <- odin_parse({ update(x[, ]) <- x[i, j] + nrow(x) / ncol(x) @@ -1269,3 +1268,46 @@ test_that("can use nrow() and ncol() on the rhs", { " }", "}")) }) + + +test_that("can generate complete sums over arrays", { + dat <- odin_parse({ + update(x) <- sum(y) + initial(x) <- 0 + y[] <- Normal(0, 1) + dim(y) <- 3 + }) + dat <- generate_prepare(dat) + expect_equal( + generate_dust_system_update(dat), + c(method_args$update, + " for (size_t i = 1; i <= shared.dim.y.size; ++i) {", + " internal.y[i - 1] = monty::random::normal(rng_state, 0, 1);", + " }", + " state_next[0] = dust2::array::sum(internal.y, shared.dim.y);", + "}")) +}) + + +test_that("can generate partial sums over arrays", { + dat <- odin_parse({ + update(x[]) <- sum(y[i, ]) + initial(x[]) <- 0 + y[, ] <- Normal(0, 1) + dim(y) <- c(3, 4) + dim(x) <- 3 + }) + dat <- generate_prepare(dat) + expect_equal( + generate_dust_system_update(dat), + c(method_args$update, + " for (size_t i = 1; i <= shared.dim.y.dim[0]; ++i) {", + " for (size_t j = 1; j <= shared.dim.y.dim[1]; ++j) {", + " internal.y[i - 1 + (j - 1) * (shared.dim.y.mult[1])] = monty::random::normal(rng_state, 0, 1);", + " }", + " }", + " for (size_t i = 1; i <= shared.dim.x.size; ++i) {", + " state_next[i - 1 + 0] = dust2::array::sum(internal.y, shared.dim.y, {i - 1, i - 1}, {0, shared.dim.y.dim[1] - 1});", + " }", + "}")) +}) diff --git a/tests/testthat/test-parse-expr-array.R b/tests/testthat/test-parse-expr-array.R index 462b6e8e..40a0e581 100644 --- a/tests/testthat/test-parse-expr-array.R +++ b/tests/testthat/test-parse-expr-array.R @@ -85,3 +85,65 @@ test_that("array equations assigning with brackets", { "Array expressions must always use '[]' on the lhs", fixed = TRUE) }) + + +test_that("can parse call to a whole array", { + res <- parse_expr(quote(a <- sum(x)), NULL, NULL) + expect_equal(res$rhs$expr, + quote(OdinReduce("sum", "x", index = NULL))) + expect_equal(res$rhs$depends, + list(functions = "sum", variables = "x")) +}) + + +test_that("all empty index on sum is a complete sum", { + expr <- quote(OdinReduce("sum", "x", index = NULL)) + expect_equal(parse_expr(quote(a <- sum(x[])), NULL, NULL)$rhs$expr, expr) + expect_equal(parse_expr(quote(a <- sum(x[, ])), NULL, NULL)$rhs$expr, expr) + expect_equal(parse_expr(quote(a <- sum(x[, , ])), NULL, NULL)$rhs$expr, expr) +}) + + +test_that("can parse call sum over part of array", { + res <- parse_expr(quote(a[] <- sum(x[, i])), NULL, NULL) + expect_equal( + res$rhs$expr, + call("OdinReduce", "sum", "x", index = list( + list(name = "i", type = "range", from = 1, to = quote(OdinDim("x", 1L))), + list(name = "j", type = "single", at = quote(i))))) + expect_equal(res$rhs$depends, + list(functions = c("sum", "["), variables = "x")) +}) + + +test_that("can parse call sum over part of part of array", { + res <- parse_expr(quote(a[] <- sum(x[a:b, i])), NULL, NULL) + expect_equal( + res$rhs$expr, + call("OdinReduce", "sum", "x", index = list( + list(name = "i", type = "range", from = quote(a), to = quote(b)), + list(name = "j", type = "single", at = quote(i))))) + expect_equal(res$rhs$depends, + list(functions = c("sum", "[", ":"), + variables = c("x", "a", "b"))) +}) + + +test_that("can check that sum has the right number of arguments", { + expect_error( + parse_expr(quote(a <- sum(x, 1)), NULL, NULL), + "Invalid call to 'sum': incorrect number of arguments") +}) + + +test_that("argument to quote must be symbol or array access", { + expect_error( + parse_expr(quote(a <- sum(x + y)), NULL, NULL), + "Expected argument to 'sum' to be an array") +}) + +test_that("require that range access is simple", { + expect_error( + parse_expr(quote(y <- sum(x[a:b + c])), NULL, NULL), + "Invalid use of range operator ':' within 'sum' call") +}) diff --git a/vignettes/errors.Rmd b/vignettes/errors.Rmd index d121ccd9..c7da0fea 100644 --- a/vignettes/errors.Rmd +++ b/vignettes/errors.Rmd @@ -399,6 +399,60 @@ a <- parameter(type = "integer", differentiate = TRUE) Here, you must decide if `a` should be differentiable (in which case remove the `type` argument) or an integer (in which case remove the `differentiate` argument). +# `E1033` + +The argument to sum must be an array. This can either be a complete array (in which case the argument will be a symbol), or an indexed array. So these are both fine: + +```r +a <- sum(x) +b[] <- sum(x[, i]) +``` + +with the first summing over the whole array and the second summing over rows (each element of `b` will contain a sum over the corresponding row of `x`. + +But these are errors: + +```r +a <- sum(a + y) +b[] <- sum([, i] + 1) +``` + +Because summation is associative (or commutative) in this case we could write: + +```r +a <- sum(a) + sum(y) +b[] <- sum([, i]) + 1 +``` + +but in more complicated cases you may have to jump through more hoops to get the expression you want, and this may involve saving out an intermediate variable. For example, rather than writing: + +```r +a <- sum(x^2) +``` + +You might write: + +```r +xx[] <- x[i]^2 +a <- sum(xx) +``` + +# `E1034` + +Invalid use of `:` within a partial sum. If you use `:` it must be the *outermost* operator within an index, so this is fine: + +```r +sum(x[a:b]) +``` + +but this is not + +```r +sum(x[a:b + 1]) +``` + +See [E1022](#e1022) for more information in the case where this same class of error is applied to indexing the left hand side of an assignment. + # `E2001` Your system of equations does not include any expressions with `initial()` on the lhs. This is what we derive the set of variables from, so at least one must be present. diff --git a/vignettes/functions.Rmd b/vignettes/functions.Rmd index f5ad6539..0753c39e 100644 --- a/vignettes/functions.Rmd +++ b/vignettes/functions.Rmd @@ -164,6 +164,32 @@ We provide several functions for retrieving dimensions from an array We do not currently offer a function for accessing the size of higher dimensions, please let us know if this is an issue (see `vignette("migrating")`) +Frequently, you will want to take a sum over an array, or part of an array, using `sum`. To sum over all elements of an array, use `sum()` with the name of the array you would like to sum over: + +```r +dim(x) <- 10 +x_tot <- sum(x) +``` + +If `m` is a matrix you can compute the sums over the second column by writing: + +```r +m1_tot <- sum(m[, 2]) +``` + +This partial sum approach is frequently used within implicit loops: + +``` +m_col_totals[] <- sum([, i]) +``` + +You can use this approach to compute a matrix-vector product $\mathbf(Ax)$: + +``` +ax_tmp[, ] <- a[i, j] * x[j] +ax[] <- sum(a[, i]) +``` + # Distribution functions We support distribution functions in two places: