From 6bb9c3428c86aabb5bba30fb0dcbb0a33a668290 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn <r.fitzjohn@imperial.ac.uk> Date: Thu, 13 Jul 2023 10:56:03 +0100 Subject: [PATCH] Simplify unpack implementation --- R/generate_dust.R | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/R/generate_dust.R b/R/generate_dust.R index dbeaa84..d6e7ff6 100644 --- a/R/generate_dust.R +++ b/R/generate_dust.R @@ -239,8 +239,7 @@ generate_dust_core_update <- function(eqs, dat, rewrite) { variables <- dat$components$rhs$variables equations <- dat$components$rhs$equations - unpack <- lapply(variables, dust_unpack_variable, - dat, dat$meta$state, rewrite) + unpack <- lapply(variables, dust_unpack_variable, dat, rewrite) debug <- generate_dust_debug(dat$debug, dat, rewrite) body <- dust_flatten_eqs(c(unpack, eqs[equations], debug)) @@ -256,8 +255,7 @@ generate_dust_core_update_stochastic <- function(eqs, dat, rewrite) { variables <- dat$components$update_stochastic$variables equations <- dat$components$update_stochastic$equations - unpack <- lapply(variables, dust_unpack_variable, - dat, dat$meta$state, rewrite) + unpack <- lapply(variables, dust_unpack_variable, dat, rewrite) body <- dust_flatten_eqs( c(unpack, @@ -275,8 +273,7 @@ generate_dust_core_output <- function(eqs, dat, rewrite) { variables <- dat$components$output$variables equations <- dat$components$output$equations - unpack <- lapply(variables, dust_unpack_variable, - dat, dat$meta$state, rewrite) + unpack <- lapply(variables, dust_unpack_variable, dat, rewrite) body <- dust_flatten_eqs(c(unpack, eqs[equations])) args <- c("double" = dat$meta$time, "const std::vector<double>&" = dat$meta$state, @@ -289,8 +286,7 @@ generate_dust_core_rhs <- function(eqs, dat, rewrite) { variables <- dat$components$rhs$variables equations <- dat$components$rhs$equations - unpack <- lapply(variables, dust_unpack_variable, - dat, dat$meta$state, rewrite) + unpack <- lapply(variables, dust_unpack_variable, dat, rewrite) body <- dust_flatten_eqs(c(unpack, eqs[equations])) args <- c("double" = dat$meta$time, @@ -515,9 +511,7 @@ generate_dust_core_attributes <- function(dat) { } -dust_unpack_variable <- function(name, dat, state, rewrite) { - ## Here, we only ever used 'dat$meta$state' as the state arg here so - ## it can go, and we assume that. +dust_unpack_variable <- function(name, dat, rewrite) { data_info <- dat$data$elements[[name]] location <- switch(dat$data$elements[[name]]$location, variable = dat$meta$state, @@ -686,8 +680,7 @@ generate_dust_core_compare <- function(eqs, dat, rewrite) { } variables <- dat$components$compare$variables equations <- dat$components$compare$equations - unpack <- lapply(variables, dust_unpack_variable, - dat, dat$meta$state, rewrite) + unpack <- lapply(variables, dust_unpack_variable, dat, rewrite) collect <- generate_dust_compare_collect(dat) body <- dust_flatten_eqs(c(unpack, eqs[equations], collect)) args <- c("const real_type *" = dat$meta$state, @@ -1246,8 +1239,7 @@ generate_dust_debug <- function(debug, dat, rewrite) { dat$components$rhs$variables), unlist(lapply(debug, function(x) x$depends$variables))) if (length(msg) > 0) { - ret$add(dust_flatten_eqs( - lapply(msg, dust_unpack_variable, dat, dat$meta$state, rewrite))) + ret$add(dust_flatten_eqs(lapply(msg, dust_unpack_variable, dat, rewrite))) } time_fmt <- if (dat$features$continuous) "%f" else "%d" @@ -1333,7 +1325,7 @@ generate_dust_core_adjoint_initial <- function(eqs, dat, rewrite) { adjoint_initial <- dat$derivative$adjoint$components$initial unpack <- lapply(adjoint_initial$variables, dust_unpack_variable, dat, - dat$meta$state, rewrite) + rewrite) body <- dust_flatten_eqs(c(unpack, eqs[adjoint_initial$equations])) cpp_function("void", "adjoint_initial", args, body) } @@ -1351,7 +1343,7 @@ generate_dust_core_adjoint_update <- function(eqs, dat, rewrite) { adjoint_update <- dat$derivative$adjoint$components$rhs unpack <- lapply(adjoint_update$variables, dust_unpack_variable, dat, - dat$meta$state, rewrite) + rewrite) body <- dust_flatten_eqs(c(unpack, eqs[adjoint_update$equations])) cpp_function("void", "adjoint_update", args, body) } @@ -1366,7 +1358,7 @@ generate_dust_core_adjoint_compare <- function(eqs, dat, rewrite) { adjoint_compare <- dat$derivative$adjoint$components$compare unpack <- lapply(adjoint_compare$variables, dust_unpack_variable, dat, - dat$meta$state, rewrite) + rewrite) body <- dust_flatten_eqs(c(unpack, eqs[adjoint_compare$equations])) cpp_function("void", "adjoint_compare_data", args, body) }