Skip to content

Commit

Permalink
Merge pull request #135 from mrc-ide/mrc-6135
Browse files Browse the repository at this point in the history
Update to work with delays from more recent dust
  • Loading branch information
weshinsley authored Dec 17, 2024
2 parents 250dc81 + 26ea238 commit f266b0f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 14 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin2
Title: Next generation odin
Version: 0.3.15
Version: 0.3.16
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand All @@ -17,12 +17,13 @@ URL: https://mrc-ide.github.io/odin2, https://github.com/mrc-ide/odin2
BugReports: https://github.com/mrc-ide/odin2/issues
Imports:
cli,
dust2 (>= 0.3.7),
dust2 (>= 0.3.9),
glue,
monty (>= 0.3.11),
rlang
Suggests:
decor,
deSolve,
fs,
knitr,
rmarkdown,
Expand Down
11 changes: 5 additions & 6 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,11 @@ generate_dust_system_delays <- function(dat) {
dat$sexp_data)
if (target %in% arrays$name) {
size <- generate_dust_sexp(call("OdinLength", target), dat$sexp_data)
index <- sprintf("dust2::tools::integer_sequence(%s, %s)",
size, offset)
} else {
index <- sprintf("{%s}", offset)
size <- "1"
}
body$add(sprintf("const dust2::ode::delay<real_type> %s{%s, %s};",
index <- sprintf("{%s, %s}", offset, size)
body$add(sprintf("const dust2::ode::delay<real_type> %s(%s, {%s});",
nm, by, index))
}

Expand Down Expand Up @@ -973,10 +972,10 @@ generate_dust_system_delay_equation <- function(nm, dat) {

if (delay_type == "variable") {
if (is_array) {
ret <- sprintf("const auto& %s = delays[%d];",
ret <- sprintf("const auto& %s = delays[%d].data;",
nm, i - 1)
} else {
ret <- sprintf("const auto %s = delays[%d][0];",
ret <- sprintf("const auto %s = delays[%d].data[0];",
nm, i - 1)
}
} else {
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test-generate-delay.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ test_that("can generate a very simple delay", {
expect_equal(
generate_dust_system_delays(dat),
c(method_args$delays,
" const dust2::ode::delay<real_type> a{1, {0}};",
" const dust2::ode::delay<real_type> a(1, {{0, 1}});",
" return dust2::ode::delays<real_type>({a});",
"}"))

expect_equal(
generate_dust_system_rhs(dat),
c(method_args$rhs_delays,
" const auto a = delays[0][0];",
" const auto a = delays[0].data[0];",
" const auto x = state[0];",
" state_deriv[0] = x - a;",
"}"))
Expand All @@ -35,14 +35,14 @@ test_that("can generate a delayed array", {
expect_equal(
generate_dust_system_delays(dat),
c(method_args$delays,
" const dust2::ode::delay<real_type> a{1, dust2::tools::integer_sequence(shared.dim.x.size, 0)};",
" const dust2::ode::delay<real_type> a(1, {{0, shared.dim.x.size}});",
" return dust2::ode::delays<real_type>({a});",
"}"))

expect_equal(
generate_dust_system_rhs(dat),
c(method_args$rhs_delays,
" const auto& a = delays[0];",
" const auto& a = delays[0].data;",
" const auto * x = state + 0;",
" for (size_t i = 1; i <= shared.dim.x.size; ++i) {",
" state_deriv[i - 1 + 0] = x[i - 1] - a[i - 1];",
Expand All @@ -64,7 +64,7 @@ test_that("can generate delay in output", {
expect_equal(
generate_dust_system_delays(dat),
c(method_args$delays,
" const dust2::ode::delay<real_type> a{1, {0}};",
" const dust2::ode::delay<real_type> a(1, {{0, 1}});",
" return dust2::ode::delays<real_type>({a});",
"}"))

Expand All @@ -77,7 +77,7 @@ test_that("can generate delay in output", {
expect_equal(
generate_dust_system_output(dat),
c(method_args$output_delays,
" const auto a = delays[0][0];",
" const auto a = delays[0].data[0];",
" const auto x = state[0];",
" state[1] = x - a;",
"}"))
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test-zzz-integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,39 @@ test_that("Can generate an ode system with output", {

expect_equal(y, logistic(pars$r, pars$K, t, rep(1, 5)), tolerance = 1e-5)
})


test_that("can compile model with delays", {
skip_if_not_installed("deSolve")
gen <- odin({
ylag <- delay(y, tau)
initial(y) <- 0.5
deriv(y) <- 0.2 * ylag * 1 / (1 + ylag^10) - 0.1 * y
tau <- parameter(10, constant = TRUE)
}, quiet = TRUE, debug = TRUE)

rhs <- function(t, y, pars) {
tau <- pars$tau
if (t < tau) {
ylag <- 0.5
} else {
ylag <- deSolve::lagvalue(t - tau)
}
list(0.2 * ylag * 1 / (1 + ylag^10) - 0.1 * y)
}

t <- seq(0, 300, length.out = 301)
sys1 <- dust2::dust_system_create(gen)
dust2::dust_system_set_state_initial(sys1)
y1 <- dust2::dust_system_simulate(sys1, t)

## Compare against deSolve
z1 <- deSolve::dede(0.5, t, rhs, list(tau = 10))
expect_equal(drop(y1), z1[, 2], tolerance = 1e-5)

sys2 <- dust2::dust_system_create(gen, list(tau = 20))
dust2::dust_system_set_state_initial(sys2)
y2 <- dust2::dust_system_simulate(sys2, t)
z2 <- deSolve::dede(0.5, t, rhs, list(tau = 20))
expect_equal(drop(y2), z2[, 2], tolerance = 1e-5)
})

0 comments on commit f266b0f

Please sign in to comment.