Skip to content

Commit

Permalink
Add terra predict support, fixes #78
Browse files Browse the repository at this point in the history
  • Loading branch information
kenkellner committed Sep 15, 2024
1 parent ea4f25f commit 775f869
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Imports:
rstan (>= 2.26.0),
rstantools (>= 2.0.0),
stats
Suggests: knitr, raster, rmarkdown, testthat
Suggests: knitr, raster, rmarkdown, terra, testthat
VignetteBuilder: knitr
Description: Fit Bayesian hierarchical models of animal abundance and occurrence
via the 'rstan' package, the R interface to the 'Stan' C++ library.
Expand Down
23 changes: 19 additions & 4 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#'
#' @param object A fitted model of class \code{ubmsFit}
#' @param submodel Submodel to predict from, for example \code{"det"}
#' @param newdata Optional data frame or RasterStack of covariates to generate
#' @param newdata Optional data frame, SpatRaster, or RasterStack of covariates to generate
#' predictions from. If not provided (the default), predictions are
#' generated from the original data
#' @param transform If \code{TRUE}, back-transform the predictions to their
Expand All @@ -25,9 +25,9 @@
#' For parameters with more than one dimension, the rows are in site-major
#' order, or site-year-observation for dynamic models.
#'
#' If \code{newdata} was a RasterStack, returns a RasterStack with four
#' layers corresponding to the four columns above with the same projection
#' as the original RasterStack.
#' If \code{newdata} was a SpatRaster/RasterStack, returns a SpatRaster/RasterStack
#' with four layers corresponding to the four columns above with the same projection
#' as the original SpatRaster/RasterStack.
#'
#' @aliases predict
#' @method predict ubmsFit
Expand All @@ -41,6 +41,8 @@ setMethod("predict", "ubmsFit",

if(inherits(newdata, c("RasterLayer", "RasterStack"))){
return(predict_raster(object, submodel, newdata, transform, re.form, level))
} else if(inherits(newdata, "SpatRaster")){
return(predict_terra(object, submodel, newdata, transform, re.form, level))
}

samples <- 1:nsamples(object)
Expand Down Expand Up @@ -69,3 +71,16 @@ predict_raster <- function(object, submodel, inp_rast, transform, re.form, level
names(out)[3:4] <- c("Lower","Upper")
out
}

predict_terra <- function(object, submodel, inp_rast, transform, re.form, level){
if(!requireNamespace("terra", quietly=TRUE)){
stop('Package "terra" is not installed', call.=FALSE)
}
df_dat <- terra::as.data.frame(inp_rast, xy=TRUE)
out <- cbind(df_dat[,1:2,drop=FALSE],
predict(object, submodel, newdata=df_dat, transform=transform,
re.form=re.form, level=level))
out <- terra::rast(out, type="xyz", crs=terra::crs(inp_rast))
names(out)[3:4] <- c("Lower","Upper")
out
}
23 changes: 22 additions & 1 deletion tests/testthat/test_ubmsFit_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ test_that("traceplot method works for ubmsFit",{
expect_is(tr, "gg")
})

test_that("predicting map works for ubmsFit",{
test_that("predicting with raster works for ubmsFit",{
skip_on_ci()
skip_if_not(requireNamespace("raster"))

set.seed(123)
r <- raster::raster(matrix(rnorm(30), ncol=5, nrow=6))
names(r) <- "x1"
r2 <- r
Expand All @@ -112,6 +115,24 @@ test_that("predicting map works for ubmsFit",{
expect_equal(length(pr_rast2), 30*4)
})

test_that("predicting with terra works for ubmsFit",{
skip_on_ci()
skip_if_not(requireNamespace("terra"))

set.seed(123)
r <- terra::rast(matrix(rnorm(30), ncol=5, nrow=6))
names(r) <- "x1"
r2 <- r
names(r2) <- "x2"
rs <- c(r, r2)
pr_terra <- predict(fit, "state", newdata=r, re.form=NA)
expect_is(pr_terra, "SpatRaster")
expect_equal(length(terra::values(pr_terra)), 30*4)
pr_terra2 <- predict(fit, "state", newdata=rs, re.form=NA)
expect_is(pr_terra2, "SpatRaster")
expect_equal(length(terra::values(pr_terra2)), 30*4)
})

test_that("getP method works for ubmsFit",{
gp <- getP(fit)
expect_equal(dim(gp), c(3,3,40))
Expand Down

0 comments on commit 775f869

Please sign in to comment.