Skip to content

Commit

Permalink
compute derivatives for compose_trans (#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Nov 3, 2023
1 parent 3df51f6 commit 669a8ec
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
29 changes: 26 additions & 3 deletions R/trans-compose.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ compose_trans <- function(...) {

names <- vapply(trans_list, "[[", "name", FUN.VALUE = character(1))

has_d_transform <- all(lengths(lapply(trans_list, "[[", "d_transform")) > 0)
has_d_inverse <- all(lengths(lapply(trans_list, "[[", "d_inverse")) > 0)

trans_new(
paste0("composition(", paste0(names, collapse = ","), ")"),
transform = function(x) compose_fwd(x, trans_list),
inverse = function(x) compose_rev(x, trans_list),
breaks = function(x) trans_list[[1]]$breaks(x),
transform = function(x) compose_fwd(x, trans_list),
inverse = function(x) compose_rev(x, trans_list),
d_transform = if (has_d_transform) function(x) compose_deriv_fwd(x, trans_list),
d_inverse = if (has_d_inverse) function(x) compose_deriv_rev(x, trans_list),
breaks = function(x) trans_list[[1]]$breaks(x),
domain = domain
)
}
Expand All @@ -49,3 +54,21 @@ compose_rev <- function(x, trans_list) {
}
x
}

compose_deriv_fwd <- function(x, trans_list) {
x_deriv <- 1
for (trans in trans_list) {
x_deriv <- trans$d_transform(x) * x_deriv
x <- trans$transform(x)
}
x_deriv
}

compose_deriv_rev <- function(x, trans_list) {
x_deriv <- 1
for (trans in rev(trans_list)) {
x_deriv <- trans$d_inverse(x) * x_deriv
x <- trans$inverse(x)
}
x_deriv
}
12 changes: 12 additions & 0 deletions tests/testthat/test-trans-compose.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ test_that("composes transforms correctly", {
expect_equal(t$inverse(-2), 100)
})

test_that("composes derivatives correctly", {
t <- compose_trans("sqrt", "reciprocal", "reverse")
expect_equal(t$d_transform(0.25), 4)
expect_equal(t$d_inverse(-2), 0.25)
})

test_that("produces NULL derivatives if not all transforms have derivatives", {
t <- compose_trans("sqrt", trans_new("no_deriv", identity, identity))
expect_null(t$d_transform)
expect_null(t$d_inverse)
})

test_that("uses breaks from first transformer", {
t <- compose_trans("log10", "reverse")
expect_equal(t$breaks(c(1, 1000)), log_breaks()(c(1, 1000)))
Expand Down

0 comments on commit 669a8ec

Please sign in to comment.