From b675d75b18d148c88b09bcd6022a4087a9028b1a Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 1 Sep 2024 18:35:21 +0200 Subject: [PATCH] add weightit tests, improve linkfun/inv (#922) * add weightit tests, improve linkfun/inv * desc * code, comment * fix * fix * typo * fix * lintr * lintr * news * fixes * Update test-weightit.R * Update test-weightit.R * Update test-weightit.R * fix * wordlist * styler * fix --- DESCRIPTION | 3 + NEWS.md | 6 + R/find_parameters.R | 2 +- R/find_statistic.R | 4 +- R/link_function.R | 70 +++++----- R/link_inverse.R | 38 +++++- inst/WORDLIST | 1 + tests/testthat/test-brms.R | 6 +- tests/testthat/test-lm.R | 41 +++--- tests/testthat/test-weightit.R | 226 +++++++++++++++++++++++++++++++++ 10 files changed, 323 insertions(+), 74 deletions(-) create mode 100644 tests/testthat/test-weightit.R diff --git a/DESCRIPTION b/DESCRIPTION index f2d1be0d29..4ebbbaadce 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -104,6 +104,7 @@ Suggests: censReg, cgam, clubSandwich, + cobalt, coxme, cplm, crch, @@ -116,6 +117,7 @@ Suggests: feisr, fixest (>= 0.11.2), fungible, + fwb, gam, gamlss, gamlss.data, @@ -204,6 +206,7 @@ Suggests: truncreg, tweedie, VGAM, + WeightIt, withr VignetteBuilder: knitr diff --git a/NEWS.md b/NEWS.md index 74547f7b9a..cff3fe5a1f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,10 +11,16 @@ `df[, 5]`, are used as response variable in the formula, as this can lead to unexpected results. +* Minor improvements to `link_function()` and `link_inverse()`. + ## Bug fixes * Fixed regression from latest fix related to `get_variance()` for *brms* models. +* Fixed issue in `link_function()` and `link_inverse()` for models of class + *cglm* with `"identity"` link, which was not correctly recognized due to a + typo. 
+ # insight 0.20.3 ## Changes diff --git a/R/find_parameters.R b/R/find_parameters.R index cc9f0e4b8e..7b6e788335 100644 --- a/R/find_parameters.R +++ b/R/find_parameters.R @@ -177,7 +177,7 @@ find_parameters.brmultinom <- find_parameters.multinom find_parameters.multinom_weightit <- function(x, flatten = FALSE, ...) { params <- stats::coef(x) resp <- gsub("(.*)~(.*)", "\\1", names(params)) - pars <- gsub("(.*)~(.*)", "\\2", names(params))[resp == resp[1]] + pars <- list(conditional = gsub("(.*)~(.*)", "\\2", names(params))[resp == resp[1]]) if (flatten) { unique(unlist(pars, use.names = FALSE)) diff --git a/R/find_statistic.R b/R/find_statistic.R index 7bec17509e..a2319da331 100644 --- a/R/find_statistic.R +++ b/R/find_statistic.R @@ -122,7 +122,7 @@ find_statistic <- function(x, ...) { "ergm", "feglm", "flexsurvreg", "gee", "ggcomparisons", "glimML", "glmm", "glmmadmb", "glmmFit", "glmmLasso", - "glmmTMB", "glmx", "gmnl", "glmgee", "glm_weightit", + "glmmTMB", "glmx", "gmnl", "glmgee", "hurdle", "lavaan", "loggammacenslmrob", "logitmfx", "logitor", "logitr", "LORgee", "lrm", "margins", "marginaleffects", "marginaleffects.summary", "metaplus", "mixor", @@ -175,7 +175,7 @@ find_statistic <- function(x, ...) { "bam", "bigglm", "cgam", "cgamm", "eglm", "emmGrid", "emm_list", - "gam", "glm", "Glm", "glmc", "glmerMod", "glmRob", "glmrob", + "gam", "glm", "Glm", "glmc", "glmerMod", "glmRob", "glmrob", "glm_weightit", "pseudoglm", "scam", "speedglm" diff --git a/R/link_function.R b/R/link_function.R index adba3395b6..a696c6c98a 100644 --- a/R/link_function.R +++ b/R/link_function.R @@ -36,28 +36,37 @@ link_function.default <- function(x, ...) 
{ x <- x$gam class(x) <- c(class(x), c("glm", "lm")) } + .extract_generic_linkfun(x) +} - tryCatch( - { - # get model family - ff <- .gam_family(x) +.extract_generic_linkfun <- function(x, default_link = NULL) { + # general approach + out <- .safe(stats::family(x)$linkfun) + # if it fails, try to retrieve from model information + if (is.null(out)) { + # get model family, consider special gam-case + ff <- .gam_family(x) + if ("linkfun" %in% names(ff)) { # return link function, if exists - if ("linkfun" %in% names(ff)) { - return(ff$linkfun) - } - + out <- ff$linkfun + } else if ("link" %in% names(ff) && is.character(ff$link)) { # else, create link function from link-string - if ("link" %in% names(ff)) { - return(match.fun(ff$link)) + out <- .safe(stats::make.link(link = ff$link)$linkfun) + # or match the function - for "exp()", make.link() won't work + if (is.null(out)) { + out <- .safe(match.fun(ff$link)) } - - NULL - }, - error = function(x) { - NULL } - ) + } + # if all fails, force default link + if (is.null(out) && !is.null(default_link)) { + out <- switch(default_link, + identity = .safe(stats::gaussian(link = "identity")$linkfun), + .safe(stats::make.link(link = default_link)$linkfun) + ) + } + out } @@ -66,7 +75,7 @@ link_function.default <- function(x, ...) { #' @export link_function.lm <- function(x, ...) { - stats::gaussian(link = "identity")$linkfun + .extract_generic_linkfun(x, "identity") } #' @export @@ -202,7 +211,7 @@ link_function.nestedLogit <- function(x, ...) { #' @export link_function.multinom <- function(x, ...) { - stats::make.link(link = "logit")$linkfun + .extract_generic_linkfun(x, "logit") } #' @export @@ -481,7 +490,7 @@ link_function.cglm <- function(x, ...) 
{ method <- parse(text = safe_deparse(x$call))[[1]]$method if (!is.null(method) && method == "clm") { - link <- "identiy" + link <- "identity" } stats::make.link(link = link)$linkfun } @@ -544,28 +553,7 @@ link_function.bcplm <- link_function.cpglmm #' @export link_function.gam <- function(x, ...) { - lf <- tryCatch( - { - # get model family - ff <- .gam_family(x) - - # return link function, if exists - if ("linkfun" %in% names(ff)) { - return(ff$linkfun) - } - - # else, create link function from link-string - if ("link" %in% names(ff)) { - return(match.fun(ff$link)) - } - - NULL - }, - error = function(x) { - NULL - } - ) - + lf <- .extract_generic_linkfun(x) if (is.null(lf)) { mi <- .gam_family(x) if (object_has_names(mi, "linfo")) { diff --git a/R/link_inverse.R b/R/link_inverse.R index 87b58d5be2..44de28207f 100644 --- a/R/link_inverse.R +++ b/R/link_inverse.R @@ -43,17 +43,45 @@ link_inverse.default <- function(x, ...) { if (inherits(x, "Zelig-relogit")) { stats::make.link(link = "logit")$linkinv } else { - .safe(stats::family(x)$linkinv) + .extract_generic_linkinv(x) } } +.extract_generic_linkinv <- function(x, default_link = NULL) { + # general approach + out <- .safe(stats::family(x)$linkinv) + # if it fails, try to retrieve from model information + if (is.null(out)) { + # get model family, consider special gam-case + ff <- .gam_family(x) + if ("linkinv" %in% names(ff)) { + # return inverse link function, if exists + out <- ff$linkinv + } else if ("link" %in% names(ff) && is.character(ff$link)) { + # else, create inverse link function from link-string + out <- .safe(stats::make.link(link = ff$link)$linkinv) + # NOTE(review): match.fun() looks like it returns the link itself, not its inverse - confirm + if (is.null(out)) { + out <- .safe(match.fun(ff$link)) + } + } + } + # if all fails, force default link + if (is.null(out) && !is.null(default_link)) { + out <- switch(default_link, + identity = .safe(stats::gaussian(link = "identity")$linkinv), + .safe(stats::make.link(link = default_link)$linkinv) + ) + 
} + out +} # GLM families --------------------------------------------------- #' @export link_inverse.glm <- function(x, ...) { - tryCatch(stats::family(x)$linkinv, error = function(x) NULL) + .extract_generic_linkinv(x, "logit") } #' @export @@ -96,7 +124,7 @@ link_inverse.flexsurvreg <- function(x, ...) { #' @export link_inverse.lm <- function(x, ...) { - stats::gaussian(link = "identity")$linkinv + .extract_generic_linkinv(x, "identity") } #' @export @@ -239,7 +267,7 @@ link_inverse.DirichletRegModel <- function(x, what = c("mean", "precision"), ... #' @export link_inverse.gmnl <- function(x, ...) { - stats::make.link("logit")$linkinv + .extract_generic_linkinv(x, "logit") } #' @export @@ -434,7 +462,7 @@ link_inverse.cglm <- function(x, ...) { method <- parse(text = safe_deparse(x$call))[[1]]$method if (!is.null(method) && method == "clm") { - link <- "identiy" + link <- "identity" } stats::make.link(link = link)$linkinv } diff --git a/inst/WORDLIST b/inst/WORDLIST index 64b39d6556..7e86f02c27 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -69,6 +69,7 @@ brms brmsfit btergm ci +cglm cloglog clubSandwich cmprsk diff --git a/tests/testthat/test-brms.R b/tests/testthat/test-brms.R index 5119006914..36891dd729 100644 --- a/tests/testthat/test-brms.R +++ b/tests/testthat/test-brms.R @@ -884,14 +884,14 @@ test_that("clean_parameters", { test_that("get_modelmatrix", { out <- get_modelmatrix(m1) expect_identical(dim(out), c(236L, 4L)) - m9 <- insight::download_model("brms_mo2") + m9 <- suppressWarnings(insight::download_model("brms_mo2")) skip_if(is.null(m9)) out <- get_modelmatrix(m9) expect_identical(dim(out), c(32L, 2L)) }) test_that("get_modelmatrix", { - m10 <- insight::download_model("brms_lf_1") + m10 <- suppressWarnings(insight::download_model("brms_lf_1")) expect_identical( find_variables(m10), list( @@ -903,7 +903,7 @@ test_that("get_modelmatrix", { # get variance test_that("get_variance works", { - mdl <- insight::download_model("brms_mixed_9") + mdl <- 
suppressWarnings(insight::download_model("brms_mixed_9")) out <- get_variance(mdl) expect_equal( out, diff --git a/tests/testthat/test-lm.R b/tests/testthat/test-lm.R index 98a1980e1c..6d89e29f13 100644 --- a/tests/testthat/test-lm.R +++ b/tests/testthat/test-lm.R @@ -75,13 +75,10 @@ test_that("get_df", { }) test_that("get_data", { - expect_equal(nrow(get_data(m1)), 150) - expect_equal( - colnames(get_data(m1)), - c("Sepal.Length", "Petal.Width", "Species") - ) - expect_equal(nrow(get_data(m2)), 32) - expect_equal(colnames(get_data(m2)), c("mpg", "hp", "cyl", "wt")) + expect_identical(nrow(get_data(m1)), 150L) + expect_named(get_data(m1), c("Sepal.Length", "Petal.Width", "Species")) + expect_identical(nrow(get_data(m2)), 32L) + expect_named(get_data(m2), c("mpg", "hp", "cyl", "wt")) }) test_that("get_intercept", { @@ -110,14 +107,14 @@ test_that("find_formula", { }) test_that("find_terms", { - expect_equal( + expect_identical( find_terms(m1), list( response = "Sepal.Length", conditional = c("Petal.Width", "Species") ) ) - expect_equal( + expect_identical( find_terms(m2), list( response = "log(mpg)", @@ -129,11 +126,11 @@ test_that("find_terms", { ) ) ) - expect_equal( + expect_identical( find_terms(m1, flatten = TRUE), c("Sepal.Length", "Petal.Width", "Species") ) - expect_equal( + expect_identical( find_terms(m2, flatten = TRUE), c( "log(mpg)", @@ -146,29 +143,29 @@ test_that("find_terms", { }) test_that("find_variables", { - expect_equal( + expect_identical( find_variables(m1), list( response = "Sepal.Length", conditional = c("Petal.Width", "Species") ) ) - expect_equal(find_variables(m2), list( + expect_identical(find_variables(m2), list( response = "mpg", conditional = c("hp", "cyl", "wt") )) - expect_equal( + expect_identical( find_variables(m1, flatten = TRUE), c("Sepal.Length", "Petal.Width", "Species") ) - expect_equal( + expect_identical( find_variables(m2, flatten = TRUE), c("mpg", "hp", "cyl", "wt") ) }) test_that("find_parameters", { - expect_equal( + 
expect_identical( find_parameters(m1), list( conditional = c( @@ -179,8 +176,8 @@ test_that("find_parameters", { ) ) ) - expect_equal(nrow(get_parameters(m1)), 4) - expect_equal( + expect_identical(nrow(get_parameters(m1)), 4L) + expect_identical( get_parameters(m1)$Parameter, c( "(Intercept)", @@ -194,7 +191,7 @@ test_that("find_parameters", { test_that("find_parameters summary.lm", { s <- summary(m1) - expect_equal( + expect_identical( find_parameters(s), list( conditional = c( @@ -213,7 +210,7 @@ test_that("linkfun", { }) test_that("find_algorithm", { - expect_equal(find_algorithm(m1), list(algorithm = "OLS")) + expect_identical(find_algorithm(m1), list(algorithm = "OLS")) }) test_that("get_variance", { @@ -235,7 +232,7 @@ test_that("all_models_equal", { }) test_that("get_varcov", { - expect_equal(diag(get_varcov(m1)), diag(vcov(m1))) + expect_equal(diag(get_varcov(m1)), diag(vcov(m1)), tolerance = 1e-5) }) test_that("get_statistic", { @@ -243,7 +240,7 @@ test_that("get_statistic", { }) test_that("find_statistic", { - expect_equal(find_statistic(m1), "t-statistic") + expect_identical(find_statistic(m1), "t-statistic") }) diff --git a/tests/testthat/test-weightit.R b/tests/testthat/test-weightit.R new file mode 100644 index 0000000000..b3cfcfc678 --- /dev/null +++ b/tests/testthat/test-weightit.R @@ -0,0 +1,226 @@ +skip_on_cran() +skip_if_not_installed("WeightIt") +skip_if_not_installed("cobalt") +skip_if_not_installed("fwb") + +data("lalonde", package = "cobalt") + +# Logistic regression ATT weights +w.out <- WeightIt::weightit( + treat ~ age + educ + married + re74, + data = lalonde, + method = "glm", + estimand = "ATT" +) +set.seed(123) +fit3 <- WeightIt::lm_weightit( + re78 ~ treat + age + educ, + data = lalonde, + weightit = w.out, + vcov = "FWB", + R = 50, # should use way more + fwb.args = list(wtype = "mammen") +) + +# Multinomial logistic regression outcome model +# that adjusts for estimation of weights +lalonde$re78_3 <- 
factor(findInterval(lalonde$re78, c(0, 5e3, 1e4))) + +fit4 <- WeightIt::multinom_weightit( + re78_3 ~ treat + age + educ, + data = lalonde, + weightit = w.out +) + +# Ordinal probit regression that adjusts for estimation +# of weights +fit5 <- WeightIt::ordinal_weightit( + ordered(re78_3) ~ treat + age + educ, + data = lalonde, + link = "probit", + weightit = w.out +) + + +test_that("model_info", { + expect_true(model_info(fit3)$is_linear) + expect_true(model_info(fit4)$is_multinomial) + expect_true(model_info(fit5)$is_ordinal) +}) + +test_that("get_residuals", { + expect_equal( + head(get_residuals(fit3)), + head(stats::residuals(fit3)), + tolerance = 1e-3, + ignore_attr = TRUE + ) + expect_equal( + head(get_residuals(fit4)), + head(stats::residuals(fit4)), + tolerance = 1e-3, + ignore_attr = TRUE + ) + expect_equal( + head(get_residuals(fit5)), + head(stats::residuals(fit5)), + tolerance = 1e-3, + ignore_attr = TRUE + ) +}) + +test_that("get_sigma", { + expect_equal(get_sigma(fit3), 5391.306, tolerance = 1e-2, ignore_attr = TRUE) + expect_equal(get_sigma(fit4), 0.4720903, tolerance = 1e-2, ignore_attr = TRUE) + expect_equal(get_sigma(fit5), 0.4753789, tolerance = 1e-2, ignore_attr = TRUE) +}) + +test_that("find_predictors", { + expect_identical(find_predictors(fit3), list(conditional = c("treat", "age", "educ"))) + expect_null(find_predictors(fit3, effects = "random")) + expect_identical(find_predictors(fit4), list(conditional = c("treat", "age", "educ"))) + expect_null(find_predictors(fit4, effects = "random")) + expect_identical(find_predictors(fit5), list(conditional = c("treat", "age", "educ"))) + expect_null(find_predictors(fit5, effects = "random")) +}) + +test_that("find_response", { + expect_identical(find_response(fit3), "re78") + expect_identical(find_response(fit4), "re78_3") + expect_identical(find_response(fit5), "re78_3") +}) + +test_that("link_inverse", { + expect_equal(link_inverse(fit3)(0.2), 0.2, tolerance = 1e-3) + 
expect_equal(link_inverse(fit4)(0.2), plogis(0.2), tolerance = 1e-3) + expect_equal(link_inverse(fit5)(0.2), 0.5792597, tolerance = 1e-3) # probit +}) + +test_that("link_function", { + expect_equal(link_function(fit3)(0.2), 0.2, tolerance = 1e-3) + expect_equal(link_function(fit4)(0.2), qlogis(0.2), tolerance = 1e-3) + expect_equal(link_function(fit5)(0.2), -0.8416212, tolerance = 1e-3) # probit +}) + +test_that("loglik", { + expect_equal(get_loglikelihood(fit3), -6361.52, tolerance = 1e-2, ignore_attr = TRUE) +}) + +test_that("get_df", { + expect_equal(get_df(fit3), df.residual(fit3), ignore_attr = TRUE) + expect_equal(get_df(fit4), df.residual(fit4), ignore_attr = TRUE) + expect_equal(get_df(fit5), df.residual(fit5), ignore_attr = TRUE) + expect_equal(get_df(fit3, type = "model"), 5, ignore_attr = TRUE) + expect_equal(get_df(fit4, type = "model"), 4, ignore_attr = TRUE) + expect_equal(get_df(fit5, type = "model"), 5, ignore_attr = TRUE) +}) + +test_that("get_data", { + expect_equal(nrow(get_data(fit3)), 614, ignore_attr = TRUE) + expect_named(get_data(fit3), c("re78", "treat", "age", "educ")) + expect_equal(nrow(get_data(fit4)), 614, ignore_attr = TRUE) + expect_named(get_data(fit5), c("re78_3", "treat", "age", "educ")) +}) + +test_that("get_intercept", { + expect_equal(get_intercept(fit3), as.vector(stats::coef(fit3)[1]), ignore_attr = TRUE) + expect_equal(get_intercept(fit4), as.vector(stats::coef(fit4)[c(1, 5)]), ignore_attr = TRUE) + expect_true(is.na(get_intercept(fit5))) +}) + +test_that("find_formula", { + expect_length(find_formula(fit3), 1) + expect_equal( + find_formula(fit3), + list(conditional = as.formula("re78 ~ treat + age + educ")), + ignore_attr = TRUE + ) + expect_equal( + find_formula(fit4), + list(conditional = as.formula("re78_3 ~ treat + age + educ")), + ignore_attr = TRUE + ) + expect_equal( + find_formula(fit5), + list(conditional = as.formula("ordered(re78_3) ~ treat + age + educ")), + ignore_attr = TRUE + ) +}) + +test_that("find_terms", 
{ + expect_identical( + find_terms(fit3), + list( + response = "re78", + conditional = c("treat", "age", "educ") + ) + ) + expect_identical( + find_terms(fit4), + list( + response = "re78_3", + conditional = c("treat", "age", "educ") + ) + ) + expect_identical( + find_terms(fit5), + list( + response = "ordered(re78_3)", + conditional = c("treat", "age", "educ") + ) + ) +}) + +test_that("find_parameters", { + expect_identical( + find_parameters(fit3), + list(conditional = c("(Intercept)", "treat", "age", "educ")) + ) + expect_identical( + find_parameters(fit4), + list(conditional = c("(Intercept)", "treat", "age", "educ")) + ) + expect_identical( + find_parameters(fit5), + list(conditional = c("treat", "age", "educ", "1|2", "2|3")) + ) + expect_identical(nrow(get_parameters(fit3)), 4L) + expect_identical(nrow(get_parameters(fit4)), 8L) + expect_identical(nrow(get_parameters(fit5)), 5L) +}) + +test_that("is_model", { + expect_true(is_model(fit3)) + expect_true(is_model(fit4)) + expect_true(is_model(fit5)) +}) + +test_that("get_varcov", { + expect_equal(diag(get_varcov(fit3)), diag(vcov(fit3)), tolerance = 1e-5) + expect_equal(diag(get_varcov(fit4)), diag(vcov(fit4)), tolerance = 1e-5) + expect_equal(diag(get_varcov(fit5)), diag(vcov(fit5)), tolerance = 1e-5) +}) + +test_that("get_statistic", { + expect_equal( + get_statistic(fit3)$Statistic, + c(0.17184, 1.33867, 0.06674, 3.15887), + tolerance = 1e-3 + ) + expect_equal( + get_statistic(fit4)$Statistic, + c(0.00271, 0.31156, -2.37521, -0.33413, -4.22954, 0.67165, -0.01399, 3.51396), + tolerance = 1e-3 + ) + expect_equal( + get_statistic(fit5)$Statistic, + c(0.60512, -0.65645, 2.98951, 2.54289, 4.76061), + tolerance = 1e-3 + ) +}) + +test_that("find_statistic", { + expect_identical(find_statistic(fit3), "t-statistic") + expect_identical(find_statistic(fit4), "z-statistic") + expect_identical(find_statistic(fit5), "z-statistic") +})