diff --git a/DESCRIPTION b/DESCRIPTION index 00d53f0..3d81106 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin2 Title: Next generation odin -Version: 0.3.15 +Version: 0.3.16 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), @@ -17,12 +17,13 @@ URL: https://mrc-ide.github.io/odin2, https://github.com/mrc-ide/odin2 BugReports: https://github.com/mrc-ide/odin2/issues Imports: cli, - dust2 (>= 0.3.7), + dust2 (>= 0.3.9), glue, monty (>= 0.3.11), rlang Suggests: decor, + deSolve, fs, knitr, rmarkdown, diff --git a/R/generate_dust.R b/R/generate_dust.R index cc518fa..bad4ee2 100644 --- a/R/generate_dust.R +++ b/R/generate_dust.R @@ -435,12 +435,11 @@ generate_dust_system_delays <- function(dat) { dat$sexp_data) if (target %in% arrays$name) { size <- generate_dust_sexp(call("OdinLength", target), dat$sexp_data) - index <- sprintf("dust2::tools::integer_sequence(%s, %s)", - size, offset) } else { - index <- sprintf("{%s}", offset) + size <- "1" } - body$add(sprintf("const dust2::ode::delay %s{%s, %s};", + index <- sprintf("{%s, %s}", offset, size) + body$add(sprintf("const dust2::ode::delay %s(%s, {%s});", nm, by, index)) } @@ -973,10 +972,10 @@ generate_dust_system_delay_equation <- function(nm, dat) { if (delay_type == "variable") { if (is_array) { - ret <- sprintf("const auto& %s = delays[%d];", + ret <- sprintf("const auto& %s = delays[%d].data;", nm, i - 1) } else { - ret <- sprintf("const auto %s = delays[%d][0];", + ret <- sprintf("const auto %s = delays[%d].data[0];", nm, i - 1) } } else { diff --git a/tests/testthat/test-generate-delay.R b/tests/testthat/test-generate-delay.R index fd9a62e..5f19c68 100644 --- a/tests/testthat/test-generate-delay.R +++ b/tests/testthat/test-generate-delay.R @@ -9,14 +9,14 @@ test_that("can generate a very simple delay", { expect_equal( generate_dust_system_delays(dat), c(method_args$delays, - " const dust2::ode::delay a{1, {0}};", + " const dust2::ode::delay a(1, {{0, 1}});", " return dust2::ode::delays({a});", "}")) expect_equal( generate_dust_system_rhs(dat), c(method_args$rhs_delays, - " const auto a = delays[0][0];", + " const auto a = delays[0].data[0];", " const auto x = state[0];", " state_deriv[0] = x - a;", "}")) @@ -35,14 +35,14 @@ test_that("can generate a delayed array", { expect_equal( generate_dust_system_delays(dat), c(method_args$delays, - " const dust2::ode::delay a{1, dust2::tools::integer_sequence(shared.dim.x.size, 0)};", + " const dust2::ode::delay a(1, {{0, shared.dim.x.size}});", " return dust2::ode::delays({a});", "}")) expect_equal( generate_dust_system_rhs(dat), c(method_args$rhs_delays, - " const auto& a = delays[0];", + " const auto& a = delays[0].data;", " const auto * x = state + 0;", " for (size_t i = 1; i <= shared.dim.x.size; ++i) {", " state_deriv[i - 1 + 0] = x[i - 1] - a[i - 1];", @@ -64,7 +64,7 @@ test_that("can generate delay in output", { expect_equal( generate_dust_system_delays(dat), c(method_args$delays, - " const dust2::ode::delay a{1, {0}};", + " const dust2::ode::delay a(1, {{0, 1}});", " return dust2::ode::delays({a});", "}")) @@ -77,7 +77,7 @@ test_that("can generate delay in output", { expect_equal( generate_dust_system_output(dat), c(method_args$output_delays, - " const auto a = delays[0][0];", + " const auto a = delays[0].data[0];", " const auto x = state[0];", " state[1] = x - a;", "}")) diff --git a/tests/testthat/test-zzz-integration.R b/tests/testthat/test-zzz-integration.R index 8b04c1c..78ed738 100644 --- a/tests/testthat/test-zzz-integration.R +++ b/tests/testthat/test-zzz-integration.R @@ -220,3 +220,39 @@ test_that("Can generate an ode system with output", { expect_equal(y, logistic(pars$r, pars$K, t, rep(1, 5)), tolerance = 1e-5) }) + + +test_that("can compile model with delays", { + skip_if_not_installed("deSolve") + gen <- odin({ + ylag <- delay(y, tau) + initial(y) <- 0.5 + deriv(y) <- 0.2 * ylag * 1 / (1 + ylag^10) - 0.1 * y + tau <- parameter(10, constant = TRUE) + }, quiet = TRUE, debug = TRUE) + + rhs <- function(t, y, pars) { + tau <- pars$tau + if (t < tau) { + ylag <- 0.5 + } else { + ylag <- deSolve::lagvalue(t - tau) + } + list(0.2 * ylag * 1 / (1 + ylag^10) - 0.1 * y) + } + + t <- seq(0, 300, length.out = 301) + sys1 <- dust2::dust_system_create(gen) + dust2::dust_system_set_state_initial(sys1) + y1 <- dust2::dust_system_simulate(sys1, t) + + ## Compare against deSolve + z1 <- deSolve::dede(0.5, t, rhs, list(tau = 10)) + expect_equal(drop(y1), z1[, 2], tolerance = 1e-5) + + sys2 <- dust2::dust_system_create(gen, list(tau = 20)) + dust2::dust_system_set_state_initial(sys2) + y2 <- dust2::dust_system_simulate(sys2, t) + z2 <- deSolve::dede(0.5, t, rhs, list(tau = 20)) + expect_equal(drop(y2), z2[, 2], tolerance = 1e-5) +})