Skip to content

Commit

Permalink
Fix crash on do_index = TRUE with extra time
Browse files Browse the repository at this point in the history
  • Loading branch information
seananderson committed Dec 12, 2023
1 parent 516244e commit a776a41
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: sdmTMB
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
Version: 0.4.1.9005
Version: 0.4.1.9006
Authors@R: c(
person(c("Sean", "C."), "Anderson", , "[email protected]",
role = c("aut", "cre"),
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# sdmTMB (development version)

* Fix crash in if `sdmTMB(..., do_index = TRUE)` and `extra_time` supplied along
with `predict_args = list(newdata = ...)` that lacked `extra_time` elements.

* Allow `get_index()` to work with missing time elements.

* Add the ability to pass a custom randomized quantile function `qres_func`
Expand Down
7 changes: 5 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1394,10 +1394,10 @@ sdmTMB <- function(
if (do_index) {
args <- list(object = out_structure, return_tmb_data = TRUE)
args <- c(args, predict_args)
tmb_data <- do.call(predict.sdmTMB, args)
if (!"newdata" %in% names(predict_args)) {
cli_warn("`newdata` must be supplied if `do_index = TRUE`.")
}
tmb_data <- do.call(predict.sdmTMB, args)
if ("bias_correct" %in% names(index_args)) {
cli_warn("`bias_correct` must be done later with `get_index(..., bias_correct = TRUE)`.")
index_args$bias_correct <- NULL
Expand All @@ -1408,7 +1408,7 @@ sdmTMB <- function(
index_args[["area"]] <- 1
}
if (length(index_args$area) == 1L) {
tmb_data$area_i <- rep(index_args[["area"]], nrow(predict_args[["newdata"]]))
tmb_data$area_i <- rep(index_args[["area"]], length(tmb_data$proj_year)) # proj_year includes padded extra_time! otherwise, crash
} else {
if (length(index_args$area) != nrow(predict_args[["newdata"]]))
cli_abort("`area` length does not match `nrow(newdata)`.")
Expand All @@ -1417,6 +1417,9 @@ sdmTMB <- function(
tmb_data$calc_index_totals <- 1L
tmb_params[["eps_index"]] <- numeric(0) # for bias correction
out_structure$do_index <- TRUE
do_index_time_missing_from_nd <-
out_structure$do_index_time_missing_from_nd <-
setdiff(data[[time]], predict_args$newdata[[time]])
} else {
out_structure$do_index <- FALSE
}
Expand Down
3 changes: 3 additions & 0 deletions R/index.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,8 @@ get_generic <- function(obj, value_name, bias_correct = FALSE, level = 0.95,
if (!is.null(obj$fake_nd)) {
d <- d[!d[[obj$fit_obj$time]] %in% obj$fake_nd[[obj$fit_obj$time]], ,drop = FALSE]
}
if ("do_index_time_missing_from_nd" %in% names(obj$fit_obj)) {
d <- d[!d[[obj$fit_obj$time]] %in% obj$fit_obj$do_index_time_missing_from_nd, ,drop = FALSE]
}
d[,c(time_name, 'est', 'lwr', 'upr', 'trans_est', 'se'), drop = FALSE]
}
1 change: 0 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ predict.sdmTMB <- function(object, newdata = NULL,
return(tmb_data)
}

# TODO: when fields are a RW, visreg call crashes R here...
new_tmb_obj <- TMB::MakeADFun(
data = tmb_data,
parameters = get_pars(object),
Expand Down
22 changes: 21 additions & 1 deletion tests/testthat/test-extra-time.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ test_that("extra time, newdata, and offsets work", {
test_that("extra_time, newdata, get_index() work", {
m <- sdmTMB(
density ~ 1,
time_varying = ~1,
time_varying = ~ 1,
time_varying_type = "ar1",
data = pcod,
family = tweedie(link = "log"),
Expand Down Expand Up @@ -88,4 +88,24 @@ test_that("extra_time, newdata, get_index() work", {
p <- predict(m, newdata = nd, return_tmb_object = TRUE)
ind5 <- get_index(p)
expect_equal(ind2[ind2$year %in% nd$year, "est"], ind5[ind5$year %in% nd$year, "est"])

# with do_index = TRUE
nd <- replicate_df(pcod, "year", unique(pcod$year))
m2 <- sdmTMB(
density ~ 1,
time_varying = ~ 1,
time_varying_type = "ar1",
data = pcod,
family = tweedie(link = "log"),
time = "year",
spatial = "off",
spatiotemporal = "off",
do_index = TRUE,
predict_args = list(newdata = nd),
index_args = list(area = 1), # used to cause crash b/c extra_time
extra_time = c(2006, 2008, 2010, 2012, 2014, 2016, 2018) # last real year is 2017
)
ind6 <- get_index(m2)
expect_identical(ind6$year, c(2003, 2004, 2005, 2007, 2009, 2011, 2013, 2015, 2017))
expect_equal(ind3$est, ind6$est)
})

0 comments on commit a776a41

Please sign in to comment.