From 6bb9c3428c86aabb5bba30fb0dcbb0a33a668290 Mon Sep 17 00:00:00 2001
From: Rich FitzJohn <r.fitzjohn@imperial.ac.uk>
Date: Thu, 13 Jul 2023 10:56:03 +0100
Subject: [PATCH] Simplify unpack implementation

---
 R/generate_dust.R | 28 ++++++++++------------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/R/generate_dust.R b/R/generate_dust.R
index dbeaa84..d6e7ff6 100644
--- a/R/generate_dust.R
+++ b/R/generate_dust.R
@@ -239,8 +239,7 @@ generate_dust_core_update <- function(eqs, dat, rewrite) {
   variables <- dat$components$rhs$variables
   equations <- dat$components$rhs$equations
 
-  unpack <- lapply(variables, dust_unpack_variable,
-                   dat, dat$meta$state, rewrite)
+  unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
   debug <- generate_dust_debug(dat$debug, dat, rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[equations], debug))
 
@@ -256,8 +255,7 @@ generate_dust_core_update_stochastic <- function(eqs, dat, rewrite) {
   variables <- dat$components$update_stochastic$variables
   equations <- dat$components$update_stochastic$equations
 
-  unpack <- lapply(variables, dust_unpack_variable,
-                   dat, dat$meta$state, rewrite)
+  unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
 
   body <- dust_flatten_eqs(
     c(unpack,
@@ -275,8 +273,7 @@ generate_dust_core_output <- function(eqs, dat, rewrite) {
   variables <- dat$components$output$variables
   equations <- dat$components$output$equations
 
-  unpack <- lapply(variables, dust_unpack_variable,
-                   dat, dat$meta$state, rewrite)
+  unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[equations]))
   args <- c("double" = dat$meta$time,
             "const std::vector<double>&" = dat$meta$state,
@@ -289,8 +286,7 @@ generate_dust_core_rhs <- function(eqs, dat, rewrite) {
   variables <- dat$components$rhs$variables
   equations <- dat$components$rhs$equations
 
-  unpack <- lapply(variables, dust_unpack_variable,
-                   dat, dat$meta$state, rewrite)
+  unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[equations]))
 
   args <- c("double" = dat$meta$time,
@@ -515,9 +511,7 @@ generate_dust_core_attributes <- function(dat) {
 }
 
 
-dust_unpack_variable <- function(name, dat, state, rewrite) {
-  ## Here, we only ever used 'dat$meta$state' as the state arg here so
-  ## it can go, and we assume that.
+dust_unpack_variable <- function(name, dat, rewrite) {
   data_info <- dat$data$elements[[name]]
   location <- switch(dat$data$elements[[name]]$location,
                      variable = dat$meta$state,
@@ -686,8 +680,7 @@ generate_dust_core_compare <- function(eqs, dat, rewrite) {
   }
   variables <- dat$components$compare$variables
   equations <- dat$components$compare$equations
-  unpack <- lapply(variables, dust_unpack_variable,
-                   dat, dat$meta$state, rewrite)
+  unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
   collect <- generate_dust_compare_collect(dat)
   body <- dust_flatten_eqs(c(unpack, eqs[equations], collect))
   args <- c("const real_type *" = dat$meta$state,
@@ -1246,8 +1239,7 @@ generate_dust_debug <- function(debug, dat, rewrite) {
                            dat$components$rhs$variables),
                    unlist(lapply(debug, function(x) x$depends$variables)))
   if (length(msg) > 0) {
-    ret$add(dust_flatten_eqs(
-      lapply(msg, dust_unpack_variable, dat, dat$meta$state, rewrite)))
+    ret$add(dust_flatten_eqs(lapply(msg, dust_unpack_variable, dat, rewrite)))
   }
 
   time_fmt <- if (dat$features$continuous) "%f" else "%d"
@@ -1333,7 +1325,7 @@ generate_dust_core_adjoint_initial <- function(eqs, dat, rewrite) {
   adjoint_initial <- dat$derivative$adjoint$components$initial
 
   unpack <- lapply(adjoint_initial$variables, dust_unpack_variable, dat,
-                   dat$meta$state, rewrite)
+                   rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[adjoint_initial$equations]))
   cpp_function("void", "adjoint_initial", args, body)
 }
@@ -1351,7 +1343,7 @@ generate_dust_core_adjoint_update <- function(eqs, dat, rewrite) {
   adjoint_update <- dat$derivative$adjoint$components$rhs
 
   unpack <- lapply(adjoint_update$variables, dust_unpack_variable, dat,
-                   dat$meta$state, rewrite)
+                   rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[adjoint_update$equations]))
   cpp_function("void", "adjoint_update", args, body)
 }
@@ -1366,7 +1358,7 @@ generate_dust_core_adjoint_compare <- function(eqs, dat, rewrite) {
 
   adjoint_compare <- dat$derivative$adjoint$components$compare
   unpack <- lapply(adjoint_compare$variables, dust_unpack_variable, dat,
-                   dat$meta$state, rewrite)
+                   rewrite)
   body <- dust_flatten_eqs(c(unpack, eqs[adjoint_compare$equations]))
   cpp_function("void", "adjoint_compare_data", args, body)
 }