Skip to content

Commit

Permalink
add weightit tests, improve linkfun/inv (#922)
Browse files Browse the repository at this point in the history
* add weightit tests, improve linkfun/inv

* desc

* code, comment

* fix

* fix

* typo

* fix

* lintr

* lintr

* news

* fixes

* Update test-weightit.R

* Update test-weightit.R

* Update test-weightit.R

* fix

* wordlist

* styler

* fix
  • Loading branch information
strengejacke authored Sep 1, 2024
1 parent 07b1aa3 commit b675d75
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 74 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ Suggests:
censReg,
cgam,
clubSandwich,
cobalt,
coxme,
cplm,
crch,
Expand All @@ -116,6 +117,7 @@ Suggests:
feisr,
fixest (>= 0.11.2),
fungible,
fwb,
gam,
gamlss,
gamlss.data,
Expand Down Expand Up @@ -204,6 +206,7 @@ Suggests:
truncreg,
tweedie,
VGAM,
WeightIt,
withr
VignetteBuilder:
knitr
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
`df[, 5]`, are used as response variable in the formula, as this can lead to
unexpected results.

* Minor improvements to `link_function()` and `link_inverse()`.

## Bug fixes

* Fixed regression from latest fix related to `get_variance()` for *brms* models.

* Fixed issue in `link_function()` and `link_inverse()` for models of class
*cglm* with `"identity"` link, which was not correctly recognized due to a
typo.

# insight 0.20.3

## Changes
Expand Down
2 changes: 1 addition & 1 deletion R/find_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ find_parameters.brmultinom <- find_parameters.multinom
find_parameters.multinom_weightit <- function(x, flatten = FALSE, ...) {
params <- stats::coef(x)
resp <- gsub("(.*)~(.*)", "\\1", names(params))
pars <- gsub("(.*)~(.*)", "\\2", names(params))[resp == resp[1]]
pars <- list(conditional = gsub("(.*)~(.*)", "\\2", names(params))[resp == resp[1]])

if (flatten) {
unique(unlist(pars, use.names = FALSE))
Expand Down
4 changes: 2 additions & 2 deletions R/find_statistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ find_statistic <- function(x, ...) {
"ergm",
"feglm", "flexsurvreg",
"gee", "ggcomparisons", "glimML", "glmm", "glmmadmb", "glmmFit", "glmmLasso",
"glmmTMB", "glmx", "gmnl", "glmgee", "glm_weightit",
"glmmTMB", "glmx", "gmnl", "glmgee",
"hurdle",
"lavaan", "loggammacenslmrob", "logitmfx", "logitor", "logitr", "LORgee", "lrm",
"margins", "marginaleffects", "marginaleffects.summary", "metaplus", "mixor",
Expand Down Expand Up @@ -175,7 +175,7 @@ find_statistic <- function(x, ...) {
"bam", "bigglm",
"cgam", "cgamm",
"eglm", "emmGrid", "emm_list",
"gam", "glm", "Glm", "glmc", "glmerMod", "glmRob", "glmrob",
"gam", "glm", "Glm", "glmc", "glmerMod", "glmRob", "glmrob", "glm_weightit",
"pseudoglm",
"scam",
"speedglm"
Expand Down
70 changes: 29 additions & 41 deletions R/link_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,37 @@ link_function.default <- function(x, ...) {
x <- x$gam
class(x) <- c(class(x), c("glm", "lm"))
}
.extract_generic_linkfun(x)
}

tryCatch(
{
# get model family
ff <- .gam_family(x)

.extract_generic_linkfun <- function(x, default_link = NULL) {
# general approach
out <- .safe(stats::family(x)$linkfun)
# if it fails, try to retrieve from model information
if (is.null(out)) {
# get model family, consider special gam-case
ff <- .gam_family(x)
if ("linkfun" %in% names(ff)) {
# return link function, if exists
if ("linkfun" %in% names(ff)) {
return(ff$linkfun)
}

out <- ff$linkfun
} else if ("link" %in% names(ff) && is.character(ff$link)) {
# else, create link function from link-string
if ("link" %in% names(ff)) {
return(match.fun(ff$link))
out <- .safe(stats::make.link(link = ff$link)$linkfun)
# or match the function - for "exp()", make.link() won't work
if (is.null(out)) {
out <- .safe(match.fun(ff$link))
}

NULL
},
error = function(x) {
NULL
}
)
}
# if all fails, force default link
if (is.null(out) && !is.null(default_link)) {
out <- switch(default_link,
identity = .safe(stats::gaussian(link = "identity")$linkfun),
.safe(stats::make.link(link = default_link)$linkfun)
)
}
out
}


Expand All @@ -66,7 +75,7 @@ link_function.default <- function(x, ...) {

#' @export
link_function.lm <- function(x, ...) {
stats::gaussian(link = "identity")$linkfun
.extract_generic_linkfun(x, "identity")
}

#' @export
Expand Down Expand Up @@ -202,7 +211,7 @@ link_function.nestedLogit <- function(x, ...) {

#' @export
link_function.multinom <- function(x, ...) {
stats::make.link(link = "logit")$linkfun
.extract_generic_linkfun(x, "logit")
}

#' @export
Expand Down Expand Up @@ -481,7 +490,7 @@ link_function.cglm <- function(x, ...) {
method <- parse(text = safe_deparse(x$call))[[1]]$method

if (!is.null(method) && method == "clm") {
link <- "identiy"
link <- "identity"
}
stats::make.link(link = link)$linkfun
}
Expand Down Expand Up @@ -544,28 +553,7 @@ link_function.bcplm <- link_function.cpglmm

#' @export
link_function.gam <- function(x, ...) {
lf <- tryCatch(
{
# get model family
ff <- .gam_family(x)

# return link function, if exists
if ("linkfun" %in% names(ff)) {
return(ff$linkfun)
}

# else, create link function from link-string
if ("link" %in% names(ff)) {
return(match.fun(ff$link))
}

NULL
},
error = function(x) {
NULL
}
)

lf <- .extract_generic_linkfun(x)
if (is.null(lf)) {
mi <- .gam_family(x)
if (object_has_names(mi, "linfo")) {
Expand Down
38 changes: 33 additions & 5 deletions R/link_inverse.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,45 @@ link_inverse.default <- function(x, ...) {
if (inherits(x, "Zelig-relogit")) {
stats::make.link(link = "logit")$linkinv
} else {
.safe(stats::family(x)$linkinv)
.extract_generic_linkinv(x)
}
}

.extract_generic_linkinv <- function(x, default_link = NULL) {
# general approach
out <- .safe(stats::family(x)$linkinv)
# if it fails, try to retrieve from model information
if (is.null(out)) {
# get model family, consider special gam-case
ff <- .gam_family(x)
if ("linkfun" %in% names(ff)) {
# return link function, if exists
out <- ff$linkinv
} else if ("link" %in% names(ff) && is.character(ff$link)) {
# else, create link function from link-string
out <- .safe(stats::make.link(link = ff$link)$linkinv)
# or match the function - for "exp()", make.link() won't work
if (is.null(out)) {
out <- .safe(match.fun(ff$link))
}
}
}
# if all fails, force default link
if (is.null(out) && !is.null(default_link)) {
out <- switch(default_link,
identity = .safe(stats::gaussian(link = "identity")$linkinv),
.safe(stats::make.link(link = default_link)$linkinv)
)
}
out
}


# GLM families ---------------------------------------------------

#' @export
link_inverse.glm <- function(x, ...) {
tryCatch(stats::family(x)$linkinv, error = function(x) NULL)
.extract_generic_linkinv(x, "logit")
}

#' @export
Expand Down Expand Up @@ -96,7 +124,7 @@ link_inverse.flexsurvreg <- function(x, ...) {

#' @export
link_inverse.lm <- function(x, ...) {
stats::gaussian(link = "identity")$linkinv
.extract_generic_linkinv(x, "identity")
}

#' @export
Expand Down Expand Up @@ -239,7 +267,7 @@ link_inverse.DirichletRegModel <- function(x, what = c("mean", "precision"), ...

#' @export
link_inverse.gmnl <- function(x, ...) {
stats::make.link("logit")$linkinv
.extract_generic_linkinv(x, "logit")
}

#' @export
Expand Down Expand Up @@ -434,7 +462,7 @@ link_inverse.cglm <- function(x, ...) {
method <- parse(text = safe_deparse(x$call))[[1]]$method

if (!is.null(method) && method == "clm") {
link <- "identiy"
link <- "identity"
}
stats::make.link(link = link)$linkinv
}
Expand Down
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ brms
brmsfit
btergm
ci
cglm
cloglog
clubSandwich
cmprsk
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-brms.R
Original file line number Diff line number Diff line change
Expand Up @@ -884,14 +884,14 @@ test_that("clean_parameters", {
test_that("get_modelmatrix", {
out <- get_modelmatrix(m1)
expect_identical(dim(out), c(236L, 4L))
m9 <- insight::download_model("brms_mo2")
m9 <- suppressWarnings(insight::download_model("brms_mo2"))
skip_if(is.null(m9))
out <- get_modelmatrix(m9)
expect_identical(dim(out), c(32L, 2L))
})

test_that("get_modelmatrix", {
m10 <- insight::download_model("brms_lf_1")
m10 <- suppressWarnings(insight::download_model("brms_lf_1"))
expect_identical(
find_variables(m10),
list(
Expand All @@ -903,7 +903,7 @@ test_that("get_modelmatrix", {

# get variance
test_that("get_variance works", {
mdl <- insight::download_model("brms_mixed_9")
mdl <- suppressWarnings(insight::download_model("brms_mixed_9"))
out <- get_variance(mdl)
expect_equal(
out,
Expand Down
Loading

0 comments on commit b675d75

Please sign in to comment.