Skip to content

Commit

Permalink
compile and pseudoinstall stan models during pkgload::load, add snaps…
Browse files Browse the repository at this point in the history
…hot tests
  • Loading branch information
hillalex committed Oct 4, 2024
1 parent 3a8d8f8 commit 2267925
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ src/stan/**
src/stan/**/*.exe
src/stan/**/*.EXE
inst/doc
bin
.idea
*.png
30 changes: 30 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
.onLoad <- function(libname, pkgname) {

# When installing the package, src/install.libs.R compiles and installs the
# stan model files. Then when loading an installed version of the package,
# instantiate::stan_package_model will look in the installation directory to
# find the executable. But pkgload::load_all does not simulate behaviour of
# src/install.libs.R so here we compile the stan models and move
# them into a place where the pkgload system.file shim can find them

if (pkgload::is_loading()) {
# When installation is simulated using load_all, the system.file shim looks
# for files in the local source directory, so here we create a temporary
# 'bin/stan' directory in the source directory so that calls to
# system.file("bin/stan/model.stan", "epikinetics") will resolve correctly
bin <- file.path(libname, pkgname, "bin")
if (!dir.exists(bin)) {
message("Creating local bin directory")
dir.create(bin, recursive = TRUE, showWarnings = FALSE)
}
bin_stan <- file.path(libname, pkgname, "bin", "stan")
source_path <- file.path(libname, pkgname, "src", "stan")
fs::dir_copy(path = source_path, new_path = bin, overwrite = TRUE)
instantiate::stan_package_compile(
models = instantiate::stan_package_model_files(path = bin_stan),
cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1")
)
message("Finished compiling models")
}
}
18 changes: 7 additions & 11 deletions src/install.libs.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@ if (!file.exists(bin)) {
}
bin_stan <- file.path(bin, "stan")
fs::dir_copy(path = "stan", new_path = bin_stan)
callr::r(
func = function(bin_stan) {
models <- instantiate::stan_package_model_files(path = bin_stan)
message(paste("Compiling models:", paste0(models, collapse = ",")))
instantiate::stan_package_compile(
models = instantiate::stan_package_model_files(path = bin_stan),
cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1")
)
},
args = list(bin_stan = bin_stan)
models <- instantiate::stan_package_model_files(path = bin_stan)
message(paste("Compiling models:", paste0(models, collapse = ",")))
instantiate::stan_package_compile(
models = instantiate::stan_package_model_files(path = bin_stan),
cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1")
)
message("Finished compiling models")
2 changes: 1 addition & 1 deletion tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
library(testthat)
library(epikinetics)

test_check("epikinetics")
test_check("epikinetics")
81 changes: 81 additions & 0 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Model fits are the same

Code
delta
Output
variable mean median sd mad q5 q95 rhat ess_bulk
lp__ -1184.11 -1178.54 58.33 56.66 -1277.41 -1101.09 1.09 35
t0_pop[1] 4.11 4.11 0.27 0.30 3.69 4.56 1.03 155
t0_pop[2] 4.77 4.77 0.26 0.26 4.33 5.17 1.04 72
t0_pop[3] 3.50 3.49 0.29 0.30 3.04 3.97 1.01 324
tp_pop[1] 9.51 9.53 0.70 0.73 8.27 10.57 1.07 40
tp_pop[2] 10.74 10.74 0.64 0.64 9.71 11.78 1.08 37
tp_pop[3] 8.84 8.84 0.84 0.88 7.47 10.14 1.08 40
ts_pop_delta[1] 52.74 52.69 2.75 2.71 48.15 57.29 1.00 494
ts_pop_delta[2] 61.67 61.58 2.82 2.95 56.79 66.53 1.01 142
ts_pop_delta[3] 49.86 49.73 2.65 2.68 45.61 54.34 1.01 355
ess_tail
103
289
220
412
340
307
227
416
370
348
# showing 10 of 10103 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

# Population trajectories are the same

Code
trajectories
Output
t me lo hi titre_type
<int> <num> <num> <num> <char>
1: 0 76.25839 56.51254 102.5247 Ancestral
2: 1 94.45253 72.00175 121.8491 Ancestral
3: 2 116.75048 90.41975 144.7452 Ancestral
4: 3 144.43023 113.33132 180.0799 Ancestral
5: 4 178.01885 141.05339 224.6319 Ancestral
---
902: 146 163.39995 128.89520 201.2052 Delta
903: 147 162.78005 128.23676 200.4452 Delta
904: 148 162.15565 127.58168 199.7415 Delta
905: 149 161.54259 126.92995 199.2712 Delta
906: 150 160.95822 126.28154 198.8112 Delta
infection_history
<char>
1: Infection naive
2: Infection naive
3: Infection naive
4: Infection naive
5: Infection naive
---
902: Previously infected (Pre-Omicron)
903: Previously infected (Pre-Omicron)
904: Previously infected (Pre-Omicron)
905: Previously infected (Pre-Omicron)
906: Previously infected (Pre-Omicron)

# Individual trajectories are the same

Code
trajectories
Output
calendar_date titre_type me lo hi time_shift
<IDat> <char> <num> <num> <num> <num>
1: 2021-03-08 Ancestral 543.79451 433.09823 670.4975 0
2: 2021-03-09 Ancestral 528.88463 429.80874 653.8752 0
3: 2021-03-10 Ancestral 545.26438 444.31407 665.8502 0
4: 2021-03-11 Ancestral 522.29561 418.67869 631.7992 0
5: 2021-03-12 Ancestral 531.74586 423.99750 648.3266 0
---
1775: 2022-08-07 Delta 91.77050 31.07070 429.5746 0
1776: 2022-08-08 Delta 91.18062 31.07408 424.9366 0
1777: 2022-08-09 Delta 94.16216 31.29551 426.2925 0
1778: 2022-08-10 Delta 90.74115 30.39902 426.8355 0
1779: 2022-08-11 Delta 92.97980 28.91787 438.0270 0

25 changes: 25 additions & 0 deletions tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
mod <- biokinetics$new(data = dat, covariate_formula = ~0 + infection_history)
suppressWarnings({delta <- mod$fit(parallel_chains = 4,
iter_warmup = 50,
iter_sampling = 100,
threads_per_chain = 4,
seed = 100)})

local_edition(3)

test_that("Model fits are the same", {
expect_snapshot(delta)
})

test_that("Population trajectories are the same", {
set.seed(1)
trajectories <- mod$simulate_population_trajectories()
expect_snapshot(trajectories)
})

test_that("Individual trajectories are the same", {
set.seed(1)
trajectories <- mod$simulate_individual_trajectories()
expect_snapshot(trajectories)
})

0 comments on commit 2267925

Please sign in to comment.