diff --git a/R/generate_dust.R b/R/generate_dust.R index d6e7ff6..5f4269a 100644 --- a/R/generate_dust.R +++ b/R/generate_dust.R @@ -380,8 +380,16 @@ generate_dust_core_info <- function(dat, rewrite) { len <- generate_dust_core_info_len(nms_variable, nms_output, dat, rewrite) body$add(sprintf("size_t len = %s;", len)) + if (dat$features$has_derivative) { + body$add(sprintf("cpp11::writable::strings adjoint({%s});", + paste(dquote(dat$derivative$parameters), collapse = ", "))) + } + body$add("using namespace cpp11::literals;") body$add("return cpp11::writable::list({") + if (dat$features$has_derivative) { + body$add(' "adjoint"_nm = adjoint,') + } body$add(' "dim"_nm = dim,') body$add(' "len"_nm = len,') body$add(' "index"_nm = index});') diff --git a/tests/testthat/test-differentiate.R b/tests/testthat/test-differentiate.R index 74e6f60..26d7d60 100644 --- a/tests/testthat/test-differentiate.R +++ b/tests/testthat/test-differentiate.R @@ -9,10 +9,18 @@ test_that("sir adjoint model works", { pars <- list(beta = 0.25, gamma = 0.1, I0 = 1) mod <- gen$new(pars, 0, 1, deterministic = TRUE) mod$set_data(d) + + ## This is the current temporary arrangement with dust and may change: + info <- mod$info() + expect_setequal(info$adjoint, c("beta", "gamma", "I0")) + res <- mod$run_adjoint() expect_equal(res$log_likelihood, -44.0256051296862, tolerance = 1e-14) + expect_equal(names(res$gradient), info$adjoint) expect_equal(res$gradient, - c(244.877646917118, -140.566517375877, 25.2152128116894), + c(beta = 244.877646917118, + gamma = -140.566517375877, + I0 = 25.2152128116894)[info$adjoint], tolerance = 1e-14) })