From e4e3d9f7c28f92bea3fc02fe7c0aec8055f68a5b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 8 Oct 2024 16:48:35 +0100 Subject: [PATCH 1/3] Parse min/max/prod properly --- R/constants.R | 2 ++ R/parse_expr.R | 4 ++- tests/testthat/test-parse-expr-array.R | 9 ++++++ tests/testthat/test-parse-expr.R | 38 ++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/R/constants.R b/R/constants.R index 082fa317..7c62bde2 100644 --- a/R/constants.R +++ b/R/constants.R @@ -59,6 +59,8 @@ FUNCTIONS <- list( nrow = 1, ncol = 1, sum = 1, + min = 1:2, + max = 1:2, as.logical = 1, as.integer = 1, as.numeric = 1, diff --git a/R/parse_expr.R b/R/parse_expr.R index deb4c205..34908c7c 100644 --- a/R/parse_expr.R +++ b/R/parse_expr.R @@ -635,7 +635,9 @@ parse_expr_usage <- function(expr, src, call) { } fn_str <- as.character(fn) ignore <- "[" - if (fn_str == "sum") { + is_reduction <- fn_str %in% c("sum", "prod") || + (fn_str %in% c("min", "max") && length(expr) == 2) + if (is_reduction) { expr <- parse_expr_usage_rewrite_reduce(expr, src, call) } else if (fn_str %in% monty::monty_dsl_distributions()$name) { expr <- parse_expr_usage_rewrite_stochastic(expr, src, call) diff --git a/tests/testthat/test-parse-expr-array.R b/tests/testthat/test-parse-expr-array.R index be588273..c4a34a71 100644 --- a/tests/testthat/test-parse-expr-array.R +++ b/tests/testthat/test-parse-expr-array.R @@ -132,6 +132,15 @@ test_that("can parse call sum over part of part of array", { }) +test_that("parse prod as a reduction", { + res <- parse_expr(quote(a <- prod(x)), NULL, NULL) + expect_equal(res$rhs$expr, + quote(OdinReduce("prod", "x", index = NULL))) + expect_equal(res$rhs$depends, + list(functions = "prod", variables = "x")) +}) + + test_that("can check that sum has the right number of arguments", { expect_error( parse_expr(quote(a <- sum(x, 1)), NULL, NULL), diff --git a/tests/testthat/test-parse-expr.R b/tests/testthat/test-parse-expr.R index bd9743ca..ef088593 100644 --- a/tests/testthat/test-parse-expr.R +++ b/tests/testthat/test-parse-expr.R @@ -490,3 +490,41 @@ test_that("prevent nested special calls", { "Invalid nested special lhs function 'update' within 'initial'", fixed = TRUE) }) + + +test_that("parse min as a reduction", { + res <- parse_expr(quote(a <- min(x)), NULL, NULL) + expect_equal(res$rhs$expr, + quote(OdinReduce("min", "x", index = NULL))) + expect_equal(res$rhs$depends, + list(functions = "min", variables = "x")) +}) + + +test_that("parse min as a 2-arg function", { + res <- parse_expr(quote(a <- min(x, y)), NULL, NULL) + expect_equal(res$rhs$expr, quote(min(x, y))) + expect_equal(res$rhs$depends, + list(functions = "min", variables = c("x", "y"))) + + expect_error( + parse_expr(quote(a <- min(x, y, z)), NULL, NULL), + "Invalid call to 'min': incorrect number of arguments") +}) + + +test_that("parse max as a reduction", { + res <- parse_expr(quote(a <- max(x)), NULL, NULL) + expect_equal(res$rhs$expr, + quote(OdinReduce("max", "x", index = NULL))) + expect_equal(res$rhs$depends, + list(functions = "max", variables = "x")) +}) + + +test_that("parse max as a 2-arg function", { + res <- parse_expr(quote(a <- max(x, y)), NULL, NULL) + expect_equal(res$rhs$expr, quote(max(x, y))) + expect_equal(res$rhs$depends, + list(functions = "max", variables = c("x", "y"))) +}) From ba0a37eeb0c732f0d362f67ff046a9056b79b0e9 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 8 Oct 2024 16:51:43 +0100 Subject: [PATCH 2/3] Support min/max --- R/generate_dust_sexp.R | 8 ++++---- tests/testthat/test-generate.R | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/R/generate_dust_sexp.R b/R/generate_dust_sexp.R index ec59881b..681a8f3e 100644 --- a/R/generate_dust_sexp.R +++ b/R/generate_dust_sexp.R @@ -201,9 +201,9 @@ generate_dust_sexp_reduce <- function(expr, dat, options) { dim <- paste0( if (isFALSE(options$shared_exists)) "dim_" else "shared.dim.", target) - stopifnot(fn == "sum") + stopifnot(fn %in% c("sum", "prod", "min", "max")) if (is.null(index)) { - sprintf("dust2::array::sum(%s, %s)", target_str, dim) + sprintf("dust2::array::%s(%s, %s)", fn, target_str, dim) } else { index_str <- paste(vcapply(index, function(el) { if (el$type == "single") { @@ -217,7 +217,7 @@ generate_dust_sexp_reduce <- function(expr, dat, options) { generate_dust_sexp(from, dat, options), generate_dust_sexp(to, dat, options)) }), collapse = ", ") - sprintf("dust2::array::sum(%s, %s, %s)", - target_str, dim, index_str) + sprintf("dust2::array::%s(%s, %s, %s)", + fn, target_str, dim, index_str) } } diff --git a/tests/testthat/test-generate.R b/tests/testthat/test-generate.R index a971e970..891212b7 100644 --- a/tests/testthat/test-generate.R +++ b/tests/testthat/test-generate.R @@ -1660,3 +1660,21 @@ test_that("Generate conditional debug", { " }", "}")) }) + + +test_that("support min/max", { + dat <- odin_parse({ + update(x) <- min(a) + max(b, c) + initial(x) <- 0 + a[] <- i + dim(a) <- 10 + b <- 20 + c <- 30 + }) + dat <- generate_prepare(dat) + expect_equal( + generate_dust_system_update(dat), + c(method_args$update, + " state_next[0] = dust2::array::min(shared.a, shared.dim.a) + std::max(shared.b, shared.c);", + "}")) +}) From c100f675ca73d4df722d5fe086b3e101949cdb4b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 08:54:51 +0100 Subject: [PATCH 3/3] Fix handling of prod For https://github.com/mrc-ide/odin/issues/316 --- R/constants.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/constants.R b/R/constants.R index 7c62bde2..28070fac 100644 --- a/R/constants.R +++ b/R/constants.R @@ -59,6 +59,7 @@ FUNCTIONS <- list( nrow = 1, ncol = 1, sum = 1, + prod = 1, min = 1:2, max = 1:2, as.logical = 1,