From 483a82a707ea449b65a9faf6283d177dae6dddb5 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 6 Nov 2023 01:53:06 -0600 Subject: [PATCH] Derivatives of transformations (#341) --- NEWS.md | 2 + R/trans-compose.R | 29 +++++- R/trans-numeric.R | 129 ++++++++++++++++------- R/trans.R | 17 +++- man/probability_trans.Rd | 5 +- man/trans_new.Rd | 13 ++- tests/testthat/test-trans-compose.R | 12 +++ tests/testthat/test-trans-numeric.R | 153 ++++++++++++++++++++++++++++ 8 files changed, 314 insertions(+), 46 deletions(-) diff --git a/NEWS.md b/NEWS.md index 775cf1fe..ce2434ff 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,8 @@ * Add an inverse (area) hyperbolic sine transformation `asinh_trans()`, which provides a logarithm-like transformation of a space, but which accommodates negative values (#297) +* Transformation objects can optionally include the derivatives of the transform + and the inverse transform (@mjskay, #322). # scales 1.2.1 diff --git a/R/trans-compose.R b/R/trans-compose.R index 26b85da2..252d4af2 100644 --- a/R/trans-compose.R +++ b/R/trans-compose.R @@ -27,11 +27,16 @@ compose_trans <- function(...) { names <- vapply(trans_list, "[[", "name", FUN.VALUE = character(1)) + has_d_transform <- all(lengths(lapply(trans_list, "[[", "d_transform")) > 0) + has_d_inverse <- all(lengths(lapply(trans_list, "[[", "d_inverse")) > 0) + trans_new( paste0("composition(", paste0(names, collapse = ","), ")"), - transform = function(x) compose_fwd(x, trans_list), - inverse = function(x) compose_rev(x, trans_list), - breaks = function(x) trans_list[[1]]$breaks(x), + transform = function(x) compose_fwd(x, trans_list), + inverse = function(x) compose_rev(x, trans_list), + d_transform = if (has_d_transform) function(x) compose_deriv_fwd(x, trans_list), + d_inverse = if (has_d_inverse) function(x) compose_deriv_rev(x, trans_list), + breaks = function(x) trans_list[[1]]$breaks(x), domain = domain ) } @@ -49,3 +54,21 @@ compose_rev <- function(x, trans_list) { } x } + +compose_deriv_fwd <- function(x, trans_list) { + x_deriv <- 1 + for (trans in trans_list) { + x_deriv <- trans$d_transform(x) * x_deriv + x <- trans$transform(x) + } + x_deriv +} + +compose_deriv_rev <- function(x, trans_list) { + x_deriv <- 1 + for (trans in rev(trans_list)) { + x_deriv <- trans$d_inverse(x) * x_deriv + x <- trans$inverse(x) + } + x_deriv +} diff --git a/R/trans-numeric.R b/R/trans-numeric.R index 509fb0ec..301e15b0 100644 --- a/R/trans-numeric.R +++ b/R/trans-numeric.R @@ -11,6 +11,8 @@ asn_trans <- function() { "asn", function(x) 2 * asin(sqrt(x)), function(x) sin(x / 2)^2, + d_transform = function(x) 1 / sqrt(x - x^2), + d_inverse = function(x) sin(x) / 2, domain = c(0, 1) ) } @@ -21,7 +23,14 @@ asn_trans <- function() { #' @examples #' plot(atanh_trans(), xlim = c(-1, 1)) atanh_trans <- function() { - trans_new("atanh", "atanh", "tanh", domain = c(-1, 1)) + trans_new( + "atanh", + "atanh", + "tanh", + d_transform = function(x) 1 / (1 - x^2), + d_inverse = function(x) 1 / cosh(x)^2, + domain = c(-1, 1) + ) } #' Inverse Hyperbolic Sine transformation @@ -33,7 +42,9 @@ asinh_trans <- function() { trans_new( "asinh", transform = asinh, - inverse = sinh + inverse = sinh, + d_transform = function(x) 1 / sqrt(x^2 + 1), + d_inverse = cosh ) } @@ -80,30 +91,35 @@ asinh_trans <- function() { #' plot(modulus_trans(1), xlim = c(-10, 10)) #' plot(modulus_trans(2), xlim = c(-10, 10)) boxcox_trans <- function(p, offset = 0) { - trans <- function(x) { + if (abs(p) < 1e-07) { + trans <- function(x) log(x + offset) + inv <- function(x) exp(x) - offset + d_trans <- function(x) 1 / (x + offset) + d_inv <- "exp" + } else { + trans <- function(x) ((x + offset)^p - 1) / p + inv <- function(x) (x * p + 1)^(1 / p) - offset + d_trans <- function(x) (x + offset)^(p - 1) + d_inv <- function(x) (x * p + 1)^(1 / p - 1) + } + + trans_with_check <- function(x) { if (any((x + offset) < 0, na.rm = TRUE)) { cli::cli_abort(c( "{.fun boxcox_trans} must be given only positive values", i = "Consider using {.fun modulus_trans} instead?" )) } - if (abs(p) < 1e-07) { - log(x + offset) - } else { - ((x + offset)^p - 1) / p - } - } - - inv <- function(x) { - if (abs(p) < 1e-07) { - exp(x) - offset - } else { - (x * p + 1)^(1 / p) - offset - } + trans(x) } trans_new( - paste0("pow-", format(p)), trans, inv, domain = c(0, Inf) + paste0("pow-", format(p)), + trans_with_check, + inv, + d_transform = d_trans, + d_inverse = d_inv, + domain = c(0, Inf) ) } @@ -113,12 +129,17 @@ modulus_trans <- function(p, offset = 1) { if (abs(p) < 1e-07) { trans <- function(x) sign(x) * log(abs(x) + offset) inv <- function(x) sign(x) * (exp(abs(x)) - offset) + d_trans <- function(x) 1 / (abs(x) + offset) + d_inv <- function(x) exp(abs(x)) } else { trans <- function(x) sign(x) * ((abs(x) + offset)^p - 1) / p inv <- function(x) sign(x) * ((abs(x) * p + 1)^(1 / p) - offset) + d_trans <- function(x) (abs(x) + offset)^(p - 1) + d_inv <- function(x) (abs(x) * p + 1)^(1 / p - 1) } trans_new( - paste0("mt-pow-", format(p)), trans, inv + paste0("mt-pow-", format(p)), trans, inv, + d_transform = d_trans, d_inverse = d_inv ) } @@ -153,34 +174,44 @@ yj_trans <- function(p) { eps <- 1e-7 if (abs(p) < eps) { - trans_pos <- function(x) log(x + 1) - inv_pos <- function(x) exp(x) - 1 + trans_pos <- log1p + inv_pos <- expm1 + d_trans_pos <- function(x) 1 / (1 + x) + d_inv_pos <- exp } else { trans_pos <- function(x) ((x + 1)^p - 1) / p inv_pos <- function(x) (p * x + 1)^(1 / p) - 1 + d_trans_pos <- function(x) (x + 1)^(p - 1) + d_inv_pos <- function(x) (p * x + 1)^(1 / p - 1) } if (abs(2 - p) < eps) { - trans_neg <- function(x) -log(-x + 1) + trans_neg <- function(x) -log1p(-x) inv_neg <- function(x) 1 - exp(-x) + d_trans_neg <- function(x) 1 / (1 - x) + d_inv_new <- function(x) exp(-x) } else { trans_neg <- function(x) -((-x + 1)^(2 - p) - 1) / (2 - p) inv_neg <- function(x) 1 - (-(2 - p) * x + 1)^(1 / (2 - p)) + d_trans_neg <- function(x) (1 - x)^(1 - p) + d_inv_neg <- function(x) (-(2 - p) * x + 1)^(1 / (2 - p) - 1) } trans_new( paste0("yeo-johnson-", format(p)), function(x) trans_two_sided(x, trans_pos, trans_neg), - function(x) trans_two_sided(x, inv_pos, inv_neg) + function(x) trans_two_sided(x, inv_pos, inv_neg), + d_transform = function(x) trans_two_sided(x, d_trans_pos, d_trans_neg, f_at_0 = 1), + d_inverse = function(x) trans_two_sided(x, d_inv_pos, d_inv_neg, f_at_0 = 1) ) } -trans_two_sided <- function(x, pos, neg) { +trans_two_sided <- function(x, pos, neg, f_at_0 = 0) { out <- rep(NA_real_, length(x)) present <- !is.na(x) out[present & x > 0] <- pos(x[present & x > 0]) out[present & x < 0] <- neg(x[present & x < 0]) - out[present & x == 0] <- 0 + out[present & x == 0] <- f_at_0 out } @@ -198,7 +229,9 @@ exp_trans <- function(base = exp(1)) { trans_new( paste0("power-", format(base)), function(x) base^x, - function(x) log(x, base = base) + function(x) log(x, base = base), + d_transform = function(x) base^x * log(base), + d_inverse = function(x) 1 / x / log(base) ) } @@ -208,7 +241,13 @@ exp_trans <- function(base = exp(1)) { #' @examples #' plot(identity_trans(), xlim = c(-1, 1)) identity_trans <- function() { - trans_new("identity", "force", "force") + trans_new( + "identity", + "force", + "force", + d_transform = function(x) rep(1, length(x)), + d_inverse = function(x) rep(1, length(x)) + ) } @@ -237,11 +276,13 @@ identity_trans <- function() { #' lines(log_trans(), xlim = c(1, 20), col = "red") log_trans <- function(base = exp(1)) { force(base) - trans <- function(x) log(x, base) - inv <- function(x) base^x - - trans_new(paste0("log-", format(base)), trans, inv, - log_breaks(base = base), + trans_new( + paste0("log-", format(base)), + function(x) log(x, base), + function(x) base^x, + d_transform = function(x) 1 / x / log(base), + d_inverse = function(x) base^x * log(base), + breaks = log_breaks(base = base), domain = c(1e-100, Inf) ) } @@ -261,7 +302,11 @@ log2_trans <- function() { #' @export log1p_trans <- function() { trans_new( - "log1p", "log1p", "expm1", + "log1p", + "log1p", + "expm1", + d_transform = function(x) 1 / (1 + x), + d_inverse = "exp", domain = c(-1 + .Machine$double.eps, Inf) ) } @@ -273,15 +318,18 @@ pseudo_log_trans <- function(sigma = 1, base = exp(1)) { trans_new( "pseudo_log", function(x) asinh(x / (2 * sigma)) / log(base), - function(x) 2 * sigma * sinh(x * log(base)) + function(x) 2 * sigma * sinh(x * log(base)), + d_transform = function(x) 1 / (sqrt(4 + x^2/sigma^2) * sigma * log(base)), + d_inverse = function(x) 2 * sigma * cosh(x * log(base)) * log(base) ) } #' Probability transformation #' #' @param distribution probability distribution. Should be standard R -#' abbreviation so that "p" + distribution is a valid probability density -#' function, and "q" + distribution is a valid quantile function. +#' abbreviation so that "p" + distribution is a valid cumulative distribution +#' function, "q" + distribution is a valid quantile function, and +#' "d" + distribution is a valid probability density function. #' @param ... other arguments passed on to distribution and quantile functions #' @export #' @examples @@ -290,11 +338,14 @@ pseudo_log_trans <- function(sigma = 1, base = exp(1)) { probability_trans <- function(distribution, ...) { qfun <- match.fun(paste0("q", distribution)) pfun <- match.fun(paste0("p", distribution)) + dfun <- match.fun(paste0("d", distribution)) trans_new( paste0("prob-", distribution), function(x) qfun(x, ...), function(x) pfun(x, ...), + d_transform = function(x) 1 / dfun(qfun(x, ...), ...), + d_inverse = function(x) dfun(x, ...), domain = c(0, 1) ) } @@ -314,7 +365,9 @@ reciprocal_trans <- function() { trans_new( "reciprocal", function(x) 1 / x, - function(x) 1 / x + function(x) 1 / x, + d_transform = function(x) -1 / x^2, + d_inverse = function(x) -1 / x^2 ) } @@ -332,6 +385,8 @@ reverse_trans <- function() { "reverse", function(x) -x, function(x) -x, + d_transform = function(x) rep(-1, length(x)), + d_inverse = function(x) rep(-1, length(x)), minor_breaks = regular_minor_breaks(reverse = TRUE) ) } @@ -349,6 +404,8 @@ sqrt_trans <- function() { "sqrt", "sqrt", function(x) ifelse(x < 0, NA_real_, x ^ 2), + d_transform = function(x) 0.5 / sqrt(x), + d_inverse = function(x) 2 * x, domain = c(0, Inf) ) } diff --git a/R/trans.R b/R/trans.R index 23a93804..38e43b69 100644 --- a/R/trans.R +++ b/R/trans.R @@ -3,14 +3,19 @@ #' A transformation encapsulates a transformation and its inverse, as well #' as the information needed to create pleasing breaks and labels. The `breaks()` #' function is applied on the un-transformed range of the data, and the -#' `format()` function takes the output of the `breaks()` function and return -#' well-formatted labels. +#' `format()` function takes the output of the `breaks()` function and returns +#' well-formatted labels. Transformations may also include the derivatives of the +#' transformation and its inverse, but are not required to. #' #' @param name transformation name #' @param transform function, or name of function, that performs the #' transformation #' @param inverse function, or name of function, that performs the #' inverse of the transformation +#' @param d_transform Optional function, or name of function, that gives the +#' derivative of the transformation. May be `NULL`. +#' @param d_inverse Optional function, or name of function, that gives the +#' derivative of the inverse of the transformation. May be `NULL`. #' @param breaks default breaks function for this transformation. The breaks #' function is applied to the un-transformed data. #' @param minor_breaks default minor breaks function for this transformation. @@ -23,17 +28,23 @@ #' @export #' @keywords internal #' @aliases trans -trans_new <- function(name, transform, inverse, breaks = extended_breaks(), +trans_new <- function(name, transform, inverse, + d_transform = NULL, d_inverse = NULL, + breaks = extended_breaks(), minor_breaks = regular_minor_breaks(), format = format_format(), domain = c(-Inf, Inf)) { if (is.character(transform)) transform <- match.fun(transform) if (is.character(inverse)) inverse <- match.fun(inverse) + if (is.character(d_transform)) d_transform <- match.fun(d_transform) + if (is.character(d_inverse)) d_inverse <- match.fun(d_inverse) structure( list( name = name, transform = transform, inverse = inverse, + d_transform = d_transform, + d_inverse = d_inverse, breaks = breaks, minor_breaks = minor_breaks, format = format, diff --git a/man/probability_trans.Rd b/man/probability_trans.Rd index 67ca1771..e0a4e82d 100644 --- a/man/probability_trans.Rd +++ b/man/probability_trans.Rd @@ -14,8 +14,9 @@ probit_trans() } \arguments{ \item{distribution}{probability distribution. Should be standard R -abbreviation so that "p" + distribution is a valid probability density -function, and "q" + distribution is a valid quantile function.} +abbreviation so that "p" + distribution is a valid cumulative distribution +function, "q" + distribution is a valid quantile function, and +"d" + distribution is a valid probability density function.} \item{...}{other arguments passed on to distribution and quantile functions} } diff --git a/man/trans_new.Rd b/man/trans_new.Rd index 1e6b8fa1..64ae23ad 100644 --- a/man/trans_new.Rd +++ b/man/trans_new.Rd @@ -11,6 +11,8 @@ trans_new( name, transform, inverse, + d_transform = NULL, + d_inverse = NULL, breaks = extended_breaks(), minor_breaks = regular_minor_breaks(), format = format_format(), @@ -30,6 +32,12 @@ transformation} \item{inverse}{function, or name of function, that performs the inverse of the transformation} +\item{d_transform}{Optional function, or name of function, that gives the +derivative of the transformation. May be \code{NULL}.} + +\item{d_inverse}{Optional function, or name of function, that gives the +derivative of the inverse of the transformation. May be \code{NULL}.} + \item{breaks}{default breaks function for this transformation. The breaks function is applied to the un-transformed data.} @@ -46,8 +54,9 @@ argument.} A transformation encapsulates a transformation and its inverse, as well as the information needed to create pleasing breaks and labels. The \code{breaks()} function is applied on the un-transformed range of the data, and the -\code{format()} function takes the output of the \code{breaks()} function and return -well-formatted labels. +\code{format()} function takes the output of the \code{breaks()} function and returns +well-formatted labels. Transformations may also include the derivatives of the +transformation and its inverse, but are not required to. } \seealso{ \Sexpr[results=rd,stage=build]{scales:::seealso_trans()} diff --git a/tests/testthat/test-trans-compose.R b/tests/testthat/test-trans-compose.R index d564fd4f..4a49dbc5 100644 --- a/tests/testthat/test-trans-compose.R +++ b/tests/testthat/test-trans-compose.R @@ -4,6 +4,18 @@ test_that("composes transforms correctly", { expect_equal(t$inverse(-2), 100) }) +test_that("composes derivatives correctly", { + t <- compose_trans("sqrt", "reciprocal", "reverse") + expect_equal(t$d_transform(0.25), 4) + expect_equal(t$d_inverse(-2), 0.25) +}) + +test_that("produces NULL derivatives if not all transforms have derivatives", { + t <- compose_trans("sqrt", trans_new("no_deriv", identity, identity)) + expect_null(t$d_transform) + expect_null(t$d_inverse) +}) + test_that("uses breaks from first transformer", { t <- compose_trans("log10", "reverse") expect_equal(t$breaks(c(1, 1000)), log_breaks()(c(1, 1000))) diff --git a/tests/testthat/test-trans-numeric.R b/tests/testthat/test-trans-numeric.R index 13252631..a7f5cbf8 100644 --- a/tests/testthat/test-trans-numeric.R +++ b/tests/testthat/test-trans-numeric.R @@ -121,3 +121,156 @@ test_that("probability transforms have domain (0,1)", { expect_equal(logit_trans()$domain, c(0, 1)) expect_equal(probit_trans()$domain, c(0, 1)) }) + +# Derivatives ------------------------------------------------------------- + +test_that("asn_trans derivatives work", { + trans <- asn_trans() + expect_equal(trans$d_transform(c(0, 0.5, 1)), c(Inf, 2, Inf)) + expect_equal(trans$d_inverse(c(0, pi/2, pi)), c(0, 0.5, 0)) + x <- seq(0.1, 0.9, length.out = 10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("atanh_trans derivatives work", { + trans <- atanh_trans() + expect_equal(trans$d_transform(c(-1, 0, 1)), c(Inf, 1, Inf)) + expect_equal(trans$d_inverse(c(-log(2), 0, log(2))), c(0.64, 1, 0.64)) + x <- seq(-0.9, 0.9, length.out = 10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("asinh_trans derivatives work", { + trans <- asinh_trans() + expect_equal(trans$d_transform(c(-1, 0, 1)), c(sqrt(2) / 2, 1, sqrt(2) / 2)) + expect_equal(trans$d_inverse(c(-log(2), 0, log(2))), c(1.25, 1, 1.25)) + x <- seq(-0.9, 0.9, length.out = 10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("boxcox_trans derivatives work", { + trans <- boxcox_trans(p = 0, offset = 1) + expect_equal(trans$d_transform(c(0, 1, 2)), c(1, 1/2, 1/3)) + expect_equal(trans$d_inverse(c(0, 1, 2)), exp(c(0, 1, 2))) + x <- 0:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) + + trans <- boxcox_trans(p = 2, offset = 2) + expect_equal(trans$d_transform(c(0, 1, 2)), c(2, 3, 4)) + expect_equal(trans$d_inverse(c(0, 0.5, 4)), c(1, sqrt(2) / 2, 1/3)) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("modulus_trans derivatives work", { + trans <- modulus_trans(p = 0, offset = 1) + expect_equal(trans$d_transform(c(-2, -1, 1, 2)), c(1/3, 1/2, 1/2, 1/3)) + expect_equal(trans$d_inverse(c(-2, -1, 1, 2)), exp(c(2, 1, 1, 2))) + x <- c(-10:-2, 2:10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) + + trans <- modulus_trans(p = 2, offset = 2) + expect_equal(trans$d_transform(c(-2, -1, 1, 2)), c(4, 3, 3, 4)) + expect_equal(trans$d_inverse(c(-4, -0.5, 0.5, 4)), c(1/3, sqrt(2) / 2, sqrt(2) / 2, 1/3)) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("yj_trans derivatives work", { + trans <- yj_trans(p = 0) + expect_equal(trans$d_transform(c(-2, -1, 1, 2)), c(3, 2, 0.5, 1/3)) + expect_equal(trans$d_inverse(c(-1/2, 1, 2)), c(sqrt(2) / 2, exp(1), exp(2))) + x <- c(-10:10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) + + trans <- yj_trans(p = 3) + expect_equal(trans$d_transform(c(-2, -1, 1, 2)), c(1/9, 1/4, 4, 9)) + expect_equal(trans$d_inverse(c(-4, -0.5, 1)), c(1/9, 4, (1/16)^(1/3))) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(0:10), 1 / trans$d_transform(trans$inverse(0:10))) +}) + +test_that("exp_trans derivatives work", { + trans <- exp_trans(10) + expect_equal(trans$d_transform(c(0, 1, 2)), c(1, 10, 100) * log(10)) + expect_equal(trans$d_inverse(c(0.1, 1, 10) / log(10)), c(10, 1, 0.1)) + x <- 1:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("identity_trans derivatives work", { + trans <- identity_trans() + expect_equal(trans$d_transform(numeric(0)), numeric(0)) + expect_equal(trans$d_transform(c(0, 1, 2)), c(1, 1, 1)) + expect_equal(trans$d_inverse(numeric(0)), numeric(0)) + expect_equal(trans$d_inverse(c(0, 1, 2)), c(1, 1, 1)) +}) + +test_that("log_trans derivatives work", { + trans <- log_trans(10) + expect_equal(trans$d_transform(c(0.1, 1, 10) / log(10)), c(10, 1, 0.1)) + expect_equal(trans$d_inverse(c(0, 1, 2)), c(1, 10, 100) * log(10)) + x <- 1:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("log1p_trans derivatives work", { + trans <- log1p_trans() + expect_equal(trans$d_transform(c(0, 1, 2)), c(1, 1/2, 1/3)) + expect_equal(trans$d_inverse(c(0, 1, 2)), exp(c(0, 1, 2))) + x <- 0:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("pseudo_log_trans derivatives work", { + trans <- pseudo_log_trans(0.5) + expect_equal(trans$d_transform(c(0, 1)), c(1, sqrt(2) / 2)) + expect_equal(trans$d_inverse(c(0, 1)), c(1, cosh(1))) + x <- 1:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("logit_trans derivatives work", { + trans <- logit_trans() + expect_equal(trans$d_transform(c(0.1, 0.5, 0.8)), c(100/9, 4, 6.25)) + expect_equal(trans$d_inverse(c(0, 1, 2)), dlogis(c(0, 1, 2))) + x <- seq(0.1, 0.9, length.out = 10) + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("reciprocal_trans derivatives work", { + trans <- reciprocal_trans() + expect_equal(trans$d_transform(c(0.1, 1, 10)), c(-100, -1, -0.01)) + expect_equal(trans$d_inverse(c(0.1, 1, 10)), c(-100, -1, -0.01)) + x <- (1:20)/10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +}) + +test_that("reverse_trans derivatives work", { + trans <- reverse_trans() + expect_equal(trans$d_transform(numeric(0)), numeric(0)) + expect_equal(trans$d_transform(c(-1, 1, 2)), c(-1, -1, -1)) + expect_equal(trans$d_inverse(numeric(0)), numeric(0)) + expect_equal(trans$d_inverse(c(-1, 1, 2)), c(-1, -1, -1)) +}) + +test_that("sqrt_trans derivatives work", { + trans <- sqrt_trans() + expect_equal(trans$d_transform(c(1, 4, 9)), c(1/2, 1/4, 1/6)) + expect_equal(trans$d_inverse(c(1, 2, 3)), c(2, 4, 6)) + x <- 1:10 + expect_equal(trans$d_transform(x), 1 / trans$d_inverse(trans$transform(x))) + expect_equal(trans$d_inverse(x), 1 / trans$d_transform(trans$inverse(x))) +})