Skip to content

Commit

Permalink
Simplify unpack implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Sep 26, 2023
1 parent 7572a25 commit 6bb9c34
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ generate_dust_core_update <- function(eqs, dat, rewrite) {
variables <- dat$components$rhs$variables
equations <- dat$components$rhs$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
debug <- generate_dust_debug(dat$debug, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations], debug))

Expand All @@ -256,8 +255,7 @@ generate_dust_core_update_stochastic <- function(eqs, dat, rewrite) {
variables <- dat$components$update_stochastic$variables
equations <- dat$components$update_stochastic$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)

body <- dust_flatten_eqs(
c(unpack,
Expand All @@ -275,8 +273,7 @@ generate_dust_core_output <- function(eqs, dat, rewrite) {
variables <- dat$components$output$variables
equations <- dat$components$output$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations]))
args <- c("double" = dat$meta$time,
"const std::vector<double>&" = dat$meta$state,
Expand All @@ -289,8 +286,7 @@ generate_dust_core_rhs <- function(eqs, dat, rewrite) {
variables <- dat$components$rhs$variables
equations <- dat$components$rhs$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations]))

args <- c("double" = dat$meta$time,
Expand Down Expand Up @@ -515,9 +511,7 @@ generate_dust_core_attributes <- function(dat) {
}


dust_unpack_variable <- function(name, dat, state, rewrite) {
## Here, we only ever used 'dat$meta$state' as the state arg here so
## it can go, and we assume that.
dust_unpack_variable <- function(name, dat, rewrite) {
data_info <- dat$data$elements[[name]]
location <- switch(dat$data$elements[[name]]$location,
variable = dat$meta$state,
Expand Down Expand Up @@ -686,8 +680,7 @@ generate_dust_core_compare <- function(eqs, dat, rewrite) {
}
variables <- dat$components$compare$variables
equations <- dat$components$compare$equations
unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
collect <- generate_dust_compare_collect(dat)
body <- dust_flatten_eqs(c(unpack, eqs[equations], collect))
args <- c("const real_type *" = dat$meta$state,
Expand Down Expand Up @@ -1246,8 +1239,7 @@ generate_dust_debug <- function(debug, dat, rewrite) {
dat$components$rhs$variables),
unlist(lapply(debug, function(x) x$depends$variables)))
if (length(msg) > 0) {
ret$add(dust_flatten_eqs(
lapply(msg, dust_unpack_variable, dat, dat$meta$state, rewrite)))
ret$add(dust_flatten_eqs(lapply(msg, dust_unpack_variable, dat, rewrite)))
}

time_fmt <- if (dat$features$continuous) "%f" else "%d"
Expand Down Expand Up @@ -1333,7 +1325,7 @@ generate_dust_core_adjoint_initial <- function(eqs, dat, rewrite) {
adjoint_initial <- dat$derivative$adjoint$components$initial

unpack <- lapply(adjoint_initial$variables, dust_unpack_variable, dat,
dat$meta$state, rewrite)
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_initial$equations]))
cpp_function("void", "adjoint_initial", args, body)
}
Expand All @@ -1351,7 +1343,7 @@ generate_dust_core_adjoint_update <- function(eqs, dat, rewrite) {
adjoint_update <- dat$derivative$adjoint$components$rhs

unpack <- lapply(adjoint_update$variables, dust_unpack_variable, dat,
dat$meta$state, rewrite)
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_update$equations]))
cpp_function("void", "adjoint_update", args, body)
}
Expand All @@ -1366,7 +1358,7 @@ generate_dust_core_adjoint_compare <- function(eqs, dat, rewrite) {

adjoint_compare <- dat$derivative$adjoint$components$compare
unpack <- lapply(adjoint_compare$variables, dust_unpack_variable, dat,
dat$meta$state, rewrite)
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_compare$equations]))
cpp_function("void", "adjoint_compare_data", args, body)
}

0 comments on commit 6bb9c34

Please sign in to comment.