diff --git a/R/parse_compat.R b/R/parse_compat.R index 13a8665..96dd257 100644 --- a/R/parse_compat.R +++ b/R/parse_compat.R @@ -250,7 +250,7 @@ parse_compat_fix_output_self <- function(expr, call) { rlang::is_call(expr$value[[2]], "output") && !isTRUE(expr$value[[3]]) if (is_output_expr) { - lhs <- expr$value[[2]] + lhs <- expr$value[[2]][[2]] rhs <- expr$value[[3]] rewrite <- (is.symbol(lhs) && @@ -261,10 +261,10 @@ parse_compat_fix_output_self <- function(expr, call) { identical(lhs[[2]], rhs[[2]])) if (rewrite) { original <- expr$value - if (is_call(lhs, "[[")) { - expr$value[[2]] <- expr$value[[2]][[2]] - } expr$value[[3]] <- TRUE + if (rlang::is_call(lhs, "[")) { + expr$value[[2]][[2]] <- lhs[[2]] + } expr <- parse_add_compat(expr, "output_self", original) } } diff --git a/R/parse_system.R b/R/parse_system.R index 42def68..9a007bb 100644 --- a/R/parse_system.R +++ b/R/parse_system.R @@ -141,6 +141,7 @@ parse_system_overall <- function(exprs, call) { is_output_flag <- vlapply(exprs[is_output], function(x) isTRUE(x$rhs$expr)) is_output_expr <- !is_output_flag + nms <- lapply(exprs[is_output], function(x) x$lhs$name) ## Rewrite expressions in output(x) <- expr style to just drop the ## special output bit now, and treat them as normal expressions. @@ -644,7 +645,8 @@ parse_system_arrays <- function(exprs, call) { err <- vlapply(exprs[i], function(x) { is.null(x$lhs$array) && !identical(x$special, "parameter") && - !identical(x$special, "delay") + !identical(x$special, "delay") && + !(identical(x$special, "output") && isTRUE(x$rhs$expr)) }) if (any(err)) { src <- lapply(exprs[i][err], "[[", "src") diff --git a/tests/testthat/test-generate.R b/tests/testthat/test-generate.R index dbae64a..4c935af 100644 --- a/tests/testthat/test-generate.R +++ b/tests/testthat/test-generate.R @@ -2805,3 +2805,23 @@ test_that("output <- TRUE version generates same code", { expect_equal(generate_dust_system(dat2), generate_dust_system(dat1)) }) + + +test_that("cope with array output", { + dat1 <- odin_parse({ + initial(x[]) <- 0 + deriv(x[]) <- x[i] + a[] <- x[i] + 1 + output(a) <- TRUE + dim(x, a) <- 5 + }) + + dat2 <- odin_parse({ + initial(x[]) <- 0 + deriv(x[]) <- x[i] + output(a[]) <- x[i] + 1 + dim(x, a) <- 5 + }) + + expect_equal(generate_dust_system(dat2), generate_dust_system(dat1)) +}) diff --git a/tests/testthat/test-parse-compat.R b/tests/testthat/test-parse-compat.R index 6c91c2a..573af51 100644 --- a/tests/testthat/test-parse-compat.R +++ b/tests/testthat/test-parse-compat.R @@ -294,3 +294,30 @@ test_that("disallow parsing interpolation to slice", { "Drop arrays from lhs of assignments from 'interpolate()'", fixed = TRUE) }) + + +test_that("warn about old-style output assignments", { + w <- expect_warning( + odin_parse({ + initial(x) <- 0 + deriv(x) <- 1 + a <- x + 1 + output(a) <- a + }), + "Use `TRUE` on rhs for 'output(x) <- x' expressions", + fixed = TRUE) +}) + + +test_that("warn about old-style output assignments in arrays", { + w <- expect_warning( + odin_parse({ + initial(x) <- 0 + deriv(x) <- 1 + a[] <- x + 1 + output(a[]) <- a[i] + dim(a) <- 1 + }), + "Use `TRUE` on rhs for 'output(x) <- x' expressions", + fixed = TRUE) +})