Skip to content

Commit

Permalink
Merge pull request #5 from seroanalytics/spline
Browse files Browse the repository at this point in the history
support spline options
  • Loading branch information
hillalex authored Sep 15, 2024
2 parents 971f681 + 369aacb commit 2e7d650
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 9 deletions.
27 changes: 20 additions & 7 deletions R/api.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ target_get_trace <- function(name,
req,
filter = NULL,
disaggregate = NULL,
scale = "natural") {
scale = "natural",
method = "auto",
span = 0.75,
k = 10) {
logger::log_info(paste("Requesting data from", name,
"with biomarker", biomarker))
dataset <- read_dataset(req, name, scale)
Expand All @@ -112,15 +115,23 @@ target_get_trace <- function(name,
groups <- split(dat, eval(parse(text = paste("~", disaggregate))))
nms <- names(groups)
return(lapply(seq_along(groups), function(i) {
model <- with_warnings(model_out(groups[[i]], xcol))
model <- with_warnings(model_out(groups[[i]],
xcol = xcol,
method = method,
span = span,
k = k))
list(name = jsonlite::unbox(nms[[i]]),
model = model$output,
raw = data_out(groups[[i]], xcol),
warnings = model$warnings)
}))
} else {
logger::log_info("Returning single trace")
model <- with_warnings(model_out(dat, xcol))
model <- with_warnings(model_out(dat,
xcol = xcol,
method = method,
span = span,
k = k))
nm <- ifelse(is.null(filter), "all", filter)
return(list(list(name = jsonlite::unbox(nm),
model = model$output,
Expand Down Expand Up @@ -149,16 +160,18 @@ read_dataset <- function(req, name, scale) {
list(data = dat, xcol = xcol)
}

model_out <- function(dat, xcol) {
model_out <- function(dat, xcol, method = "auto", span = 0.75, k = 10) {
n <- nrow(dat)
if (n == 0) {
return(list(x = list(), y = list()))
}
if (n > 1000) {
m <- mgcv::gam(value ~ s(eval(parse(text = xcol)), bs = "cs"),
if ((n > 1000 && method == "auto") || method == "gam") {
fmla <- sprintf("value ~ s(%s, bs = 'cs', k = %f)", xcol, k)
m <- mgcv::gam(eval(parse(text = fmla)),
data = dat, method = "REML")
} else {
m <- stats::loess(value ~ eval(parse(text = xcol)), data = dat, span = 0.75)
fmla <- sprintf("value ~ %s", xcol)
m <- stats::loess(fmla, data = dat, span = span)
}
range <- range(dat[, xcol], na.rm = TRUE)
xseq <- range[1]:range[2]
Expand Down
5 changes: 4 additions & 1 deletion R/router.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ get_trace <- function() {
target_get_trace,
porcelain::porcelain_input_query(disaggregate = "string",
filter = "string",
scale = "string"),
scale = "string",
method = "string",
span = "numeric",
k = "numeric"),
returning = porcelain::porcelain_returning_json("DataSeries"))
}

Expand Down
6 changes: 5 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ with_warnings <- function(expr) {
invokeRestart("muffleWarning")
}

val <- withCallingHandlers(expr, warning = w_handler)
e_handler <- function(e) {
porcelain::porcelain_stop(jsonlite::unbox(conditionMessage(e)))
}

val <- withCallingHandlers(expr, warning = w_handler, error = e_handler)
list(output = val,
warnings = my_warnings)
}
Expand Down
68 changes: 68 additions & 0 deletions tests/testthat/test-model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
test_that("model is gam if specified", {
dat <- data.frame(day = 1:100, value = rnorm(100))
res <- model_out(dat, xcol = "day", method = "gam")

m <- mgcv::gam(value ~ s(day, bs = "cs"),
data = dat, method = "REML")
xdf <- tibble::tibble(day = 1:100)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})

test_that("model is loess if specified", {
dat <- data.frame(day = 1:2000, value = rnorm(2000))
res <- model_out(dat, xcol = "day", method = "loess")

m <- stats::loess(value ~ day, data = dat, span = 0.75)
xdf <- tibble::tibble(day = 1:2000)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})

test_that("model is loess if not specified and n <= 1000", {
dat <- data.frame(day = 1:1000, value = rnorm(1000))
res <- model_out(dat, xcol = "day")

m <- stats::loess(value ~ day, data = dat, span = 0.75)
xdf <- tibble::tibble(day = 1:1000)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})

test_that("model is gam if not specified and n > 1000", {
dat <- data.frame(day = 1:1001, value = rnorm(1001))
res <- model_out(dat, xcol = "day")

m <- mgcv::gam(value ~ s(day, bs = "cs"),
data = dat, method = "REML")
xdf <- tibble::tibble(day = 1:1001)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})

test_that("model uses gam options", {
dat <- data.frame(day = 1:1001, value = rnorm(1001))
res <- model_out(dat, xcol = "day", k = 5)

m <- mgcv::gam(value ~ s(day, bs = "cs", k = 5),
data = dat, method = "REML")
xdf <- tibble::tibble(day = 1:1001)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})

test_that("model uses loess options", {
dat <- data.frame(day = 1:100, value = rnorm(100))
res <- model_out(dat, xcol = "day", span = 0.5)

m <- stats::loess(value ~ day, data = dat, span = 0.5)
xdf <- tibble::tibble(day = 1:100)
expected <- stats::predict(m, xdf)

expect_true(all(res$y == expected))
})
62 changes: 62 additions & 0 deletions tests/testthat/test-read.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,65 @@ test_that("can get log2 data", {
))
})

test_that("can use loess model options", {
dat <- data.frame(biomarker = "ab",
value = 1:5,
day = 1:5)
router <- build_routes(cookie_key)
local_add_dataset(dat, name = "testdataset")
res <- router$call(make_req("GET",
"/dataset/testdataset/trace/ab/",
qs = "method=loess&span=0.5",
HTTP_COOKIE = cookie))
expect_equal(res$status, 200)
body <- jsonlite::fromJSON(res$body)
data <- body$data

suppressWarnings(m <- stats::loess(value ~ day, data = dat, span = 0.5))
xdf <- tibble::tibble(day = 1:5)
expected <- stats::predict(m, xdf)
expect_equal(unlist(data$model[1, "y"]),
jsonlite::fromJSON(
jsonlite::toJSON(expected) # convert to/from json for consistent rounding
))
})

test_that("can use gam model options", {
dat <- data.frame(biomarker = "ab",
value = 1:5,
day = 1:5)
router <- build_routes(cookie_key)
local_add_dataset(dat, name = "testdataset")
res <- router$call(make_req("GET",
"/dataset/testdataset/trace/ab/",
qs = "method=gam&k=2",
HTTP_COOKIE = cookie))
expect_equal(res$status, 200)
body <- jsonlite::fromJSON(res$body)
data <- body$data
suppressWarnings(m <- mgcv::gam(value ~ s(day, bs = "cs", k = 2),
data = dat, method = "REML"))
xdf <- tibble::tibble(day = 1:5)
expected <- stats::predict(m, xdf)
expect_equal(unlist(data$model[1, "y"]),
jsonlite::fromJSON(
jsonlite::toJSON(expected) # convert to/from json for consistent rounding
))
})

test_that("error running the model results in a 400", {
dat <- data.frame(biomarker = "ab",
value = 1:5,
day = 1:5)
router <- build_routes(cookie_key)
local_add_dataset(dat, name = "testdataset")
res <- router$call(make_req("GET",
"/dataset/testdataset/trace/ab/",
qs = "method=gam&k=10",
HTTP_COOKIE = cookie))
expect_equal(res$status, 400)
validate_failure_schema(res$body)
body <- jsonlite::fromJSON(res$body)
expect_equal(body$errors[1, "detail"],
"day has insufficient unique values to support 10 knots: reduce k.")
})

0 comments on commit 2e7d650

Please sign in to comment.