Skip to content

Commit

Permalink
Merge pull request #2 from jvivesb/main
Browse files Browse the repository at this point in the history
Fixed minor bugs and added corr plot function and weight draws function
  • Loading branch information
ignacio82 authored Nov 9, 2021
2 parents 7bf3e4d + f8871f6 commit 8566e15
Showing 1 changed file with 74 additions and 2 deletions.
76 changes: 74 additions & 2 deletions R/factory.R
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,9 @@ bayesianSynth <- R6::R6Class(
quantity = TRUE,
tense = "past",
display_mode_name = TRUE,
title = "Contrafactual Lift",
title = "Counterfactual Lift",
xlab = "Effect of the Intervention",
units = "the contrafactual lift",
units = "the counterfactual lift",
...
) %>% return()
},
Expand Down Expand Up @@ -942,6 +942,78 @@ bayesianSynth <- R6::R6Class(
colors = c("#4daf4a", "#377eb8", "#e41a1c"),
...
) %>% return()
},
# TODO(jvives): finish implicit weight plots function to work generally
#' @description
#' Plot implicit weight distribution across draws.
#' @return ggplot object with weight distribution per unit.
weightDraws = function(){
betas <- private$fitted %>%
as.data.frame() %>%
dplyr::select(contains('beta'))

beta_names <- private$data %>%
dplyr::filter(!!private$treated == 0) %>%
dplyr::select(!!private$id) %>%
unique() %>% pull()

treated_name <- private$data %>%
dplyr::filter(!!private$treated == 1) %>%
dplyr::select(!!private$id) %>%
unique() %>% pull()

donor_names <- beta_names[beta_names %in% treated_name == FALSE]

names(betas) <- donor_names
melt_betas <- tidyr::gather(betas, ID, weight)

melt_betas %>% ggplot2::ggplot(ggplot2::aes(x=weight, y=ID, fill=ID)) +
ggridges::geom_density_ridges() +
ggplot2::theme_minimal() +
ggplot2::theme(legend.position = "none") %>% return()
},
## TODO(jvives): finish correlation plot function to work generally
#' @description
#' Plots correlations between weights across draws.
#' @return ggplot heatmap object with correlations.
weightCorr = function(){
betas <- private$fitted %>%
as.data.frame() %>%
dplyr::select(contains('beta'))

beta_names <- private$data %>%
dplyr::filter(!!private$treated == 0) %>%
dplyr::select(!!private$id) %>%
unique() %>% pull()

treated_name <- private$data %>%
dplyr::filter(!!private$treated == 1) %>%
dplyr::select(!!private$id) %>%
unique() %>% pull()

donor_names <- beta_names[beta_names %in% treated_name == FALSE]

names(betas) <- donor_names

cormat <- round(cor(betas),3)
diag(cormat) <- NA

melted_cormat <- melt(cormat)
ggplot2::ggplot(data = melted_cormat,
ggplot2::aes(x=X1,
y=X2,
fill=value)) +
ggplot2::xlab('Units') +
ggplot2::ylab('Units') +
ggplot2::labs(fill = 'Corr') +
ggplot2::geom_tile() +
ggplot2::scale_fill_gradient2(low='red',
mid = 'white',
high='green') +
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90,
vjust = 0.5,
hjust=1)) %>%
return()
}
)
)

0 comments on commit 8566e15

Please sign in to comment.