Skip to content

Commit

Permalink
Merge pull request #4 from m-clark/plot_gam_3d
Browse files Browse the repository at this point in the history
add 3d plot
  • Loading branch information
m-clark authored Aug 17, 2018
2 parents 5fd5ed2 + 96e76ba commit 5557c19
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 17 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export(n_distinct)
export(plot_coefficients)
export(plot_gam)
export(plot_gam_2d)
export(plot_gam_3d)
export(plot_gam_by)
export(plot_gam_check)
export(pull)
Expand Down
118 changes: 118 additions & 0 deletions R/plot_gam_3d.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#' @title Plot 2d smooths in 3d
#' @description 3d plot of 2d smooths for generalized additive models.
#' @param model The mgcv gam model
#'
#' @param main_var The 'x' axis.
#' @param second_var The 'y' axis'
#' @param conditional_data Values for other covariates. Default is NULL see
#' details.
#' @param n_plot Points to plot. 100 (the default) works well. Embiggen at the
#' cost of your own waiting time.
#' @param dmb Whether to use plotly's display mode bar. Default is FALSE.
#' @param ... Arguments for \link[scico]{scico}
#'
#' @details This works like \link[visibly]{plot_gam_2d}, the only difference
#' being that a 3d plot is generated instead. It uses \link[scico]{scico} for
#' the palette. It is expected that the two input variables are continuous
#' @family model visualization
#' @return A plotly surface object
#' @examples
#' library(mgcv); library(visibly)
#' set.seed(0)
#'
#' d = gamSim(2, scale=.1)$data
#' mod <- gam(y ~ s(x, z), data = d)
#' plot_gam_3d(mod, main_var = x, second_var = z)
#' plot_gam_3d(mod, main_var = x, second_var = z, palette='tokyo')
#' @export
plot_gam_3d <- function(model,
main_var,
second_var,
conditional_data = NULL,
n_plot = 100,
dmb = FALSE,
...) {

if (!inherits(model, 'gam'))
stop('This function is for gam objects from mgcv')

if(missing(main_var))
stop('main_var and second_var are required.')

if(missing(second_var))
stop('main_var and second_var are required.')

model_data <- model$model

# test_second_var <- model_data %>% pull(!!enquo(second_var))
# do_by <- n_distinct(test_second_var)

mv <- rlang::enquo(main_var)
sv <- rlang::enquo(second_var)

mv_range <- range(na.omit(pull(model_data, !!mv)))
sv_range <- range(na.omit(pull(model_data, !!sv)))

cd <- data_frame(!!quo_name(mv) := seq(mv_range[1],
mv_range[2],
length.out = n_plot),
!!quo_name(sv) := seq(sv_range[1],
sv_range[2],
length.out = n_plot)) %>%
tidyr::expand(!!mv, !!sv)

data_list <-
create_prediction_data(model_data = model_data,
conditional_data = cd) %>%
mutate(prediction = predict(model, ., type = 'response'))

mv_name <- quo_name(mv)
sv_name <- quo_name(sv)

xlo <- list(
gridcolor = 'transparent',
zerolinecolor = 'transparent',
title = mv_name
)

ylo <- list(
gridcolor = 'transparent',
zerolinecolor = 'transparent',
title = sv_name
)

zlo <- list(
ticktext = '',
gridcolor = 'transparent',
zerolinecolor = 'transparent',
title = 'Prediction'
)

colnames(data_list)[1:2] = c('x', 'y')

# Sigh, but it works
pred_mat <- matrix(data_list$prediction, n_plot, n_plot)

# override plotly's default x y z labels/text
custom_txt <- paste0("Prediction: ", round(data_list$prediction, 3),
"\n", mv_name, ": ", round(data_list$x, 3),
"\n", sv_name, ": ", round(data_list$y, 3)) %>%
matrix(n_plot, n_plot)

data_list %>%
plotly::plot_ly(x = unique(.$x),
y = unique(.$y),
colors = grDevices::colorRamp(scico::scico(nrow(.), ...))) %>%
plotly::add_surface(z = ~ pred_mat,
text = custom_txt,
hoverinfo = 'text') %>%
plotly::layout(
scene = list(# scene!
xaxis = xlo,
yaxis = ylo,
zaxis = zlo
)) %>%
theme_plotly() %>%
plotly::config(displayModeBar = dmb)
}

4 changes: 2 additions & 2 deletions man/plot_coefficients.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/plot_coefficients.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/plot_coefficients.lm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/plot_coefficients.merMod.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/plot_gam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/plot_gam_2d.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 54 additions & 0 deletions man/plot_gam_3d.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/plot_gam_check.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions tests/testthat/test_plot_gam_2d.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ test_that('plot_gam_2d fails if no second/by_var',{
expect_error(plot_gam_by(by_mod1, main_var = x2))
})

test_that('plot_gam_2d fails if not gam object',{
expect_error(plot_gam_2d(lm(y ~ x*z, d), main_var = x, second_var=z))
})

test_that('plot_gam_2d will switch to by',{
expect_message(plot_gam_2d(by_mod2, main_var = x2, second_var = fac_num))
})
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test_plot_gam_3d.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
context('Test plot_gam_3d')

# initial prep ------------------------------------------------------------

# example taken from the mgcv plot.gam help file.
library(mgcv); library(dplyr)
set.seed(0)
d <- gamSim(2, scale=.1)$data
d$misc <- rnorm(nrow(d), mean = 50, sd = 10)
b <- gam(y ~ s(x, z), data = d)


d2 <- gamSim(4)
d2$fac_num <- as.numeric(d2$fac)
by_mod1 <- gam(y ~ s(x2, by=fac), data = d2)
by_mod2 <- gam(y ~ s(x2, by=fac_num), data = d2)



# Tests -------------------------------------------------------------------

test_that('plot_gam_3d returns a plotly object',{
expect_s3_class(plot_gam_3d(b, main_var = x, second_var = z), 'plotly')
})

test_that('plot_gam_3d fails if not gam object',{
expect_error(plot_gam_3d(lm(y ~ x*z, d), main_var = x, second_var=z))
})

test_that('plot_gam_3d fails if no main_var',{
expect_error(plot_gam_3d(b))
})

test_that('plot_gam_3d fails if no second_var',{
expect_error(plot_gam_3d(b, main_var = x))
})

0 comments on commit 5557c19

Please sign in to comment.