Skip to content

Commit

Permalink
Merge pull request #18 from jcierocki/validation
Browse files Browse the repository at this point in the history
Ranger models automation and validation
  • Loading branch information
jcierocki authored May 10, 2020
2 parents 6fe28ec + 7294c7c commit 6fd060d
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 50 deletions.
Binary file added data/fitted_models.RDS
Binary file not shown.
Binary file added data/predictions.RDS
Binary file not shown.
4 changes: 2 additions & 2 deletions dataset_prep.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ data2 <- data2 %>%

data2 %>% filter_vars_by_iv(significance_thres = 0.01) %>%
initial_split(prop = 0.75) %>%
saveRDS("data/split_raw.RDS")
write_rds("data/split_raw.RDS", compress = "gz2")

data3 <- data2 %>%
factorize(bin_methods = "tree") %>%
as_tibble() %>%
filter_vars_by_iv(significance_thres = 0.01)

dataset_split <- data3 %>% initial_split(prop = 0.75) %>% saveRDS("data/split.RDS")
dataset_split <- data3 %>% initial_split(prop = 0.75) %>% write_rds("data/split.RDS", compress = "gz2")

rm(list = ls())
Binary file added figures/conf_matrix1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/conf_matrix2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/metrics.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/roc_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/roc_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/vip_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/vip_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions funs_valid.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#### validation and visualisation automating functions

get_all_metrics <- function(pred_dfs, spec = 1L) {
f <- function(df) metrics(df, Exited, .pred_class, .pred_No)
model1_metrics <- f(pred_dfs[[1]][[spec]])

all_metrics <- do.call(
function(...) bind_cols(model1_metrics, ...),
pred_dfs %>% select(-1) %>% map(~ f(.x[[spec]])$.estimate)) %>%
rename(model_1 = .estimate) %>%
return
}

exportable_conf_matrix <- function(df) {
conf_mat(df, Exited, .pred_class)$table
}
79 changes: 31 additions & 48 deletions rand_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,35 @@ rm(list = ls())

# source("dataset_prep.R")

dataset_split1 <- readRDS("data/split.RDS")
dataset_split2 <- readRDS("data/split_raw.RDS")

df_train1 <- dataset_split1 %>% training()
df_test1 <- dataset_split1 %>% testing()
df_train2 <- dataset_split2 %>% training()
df_test2 <- dataset_split2 %>% testing()

ranger_model_specs <- rand_forest("classification", 2, 1000, 5) %>%
# set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "impurity") %>%
set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "permutation", local.importance = T)

ranger_model_1 <- ranger_model_specs %>% fit(Exited ~ ., data = df_train1)

ranger_model_2 <- ranger_model_specs %>% fit(Exited ~ ., data = df_train2)
dataset_splits <- list(
read_rds("data/split.RDS"),
read_rds("data/split_raw.RDS")
)

testing_sets <- dataset_splits %>% map(~ .x %>% testing())

models_specs <- list(
rand_forest("classification", 2, 1000, 5) %>%
# set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "impurity") %>%
set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "permutation", local.importance = T)
)

spec_names <- str_c("model_", 1:length(dataset_splits))
fitted_models <- dataset_splits %>%
map(~ .x %>% training()) %>%
map2_dfc(spec_names, function(df, col_name) {
tibble(!!col_name := models_specs %>% map(~ .x %>% fit(Exited ~ ., data = df)))
})

df_pred1 <- ranger_model_1 %>%
predict(df_test1) %>%
bind_cols(df_test1)

df_pred2 <- ranger_model_2 %>%
predict(df_test2) %>%
bind_cols(df_test2)

df_pred1 %>% metrics(Exited, .pred_class)
df_pred2 %>% metrics(Exited, .pred_class)

df_pred_probs1 <- ranger_model_1 %>%
predict(df_test1, type = "prob") %>%
bind_cols(df_test1)

df_pred_probs2 <- ranger_model_2 %>%
predict(df_test2, type = "prob") %>%
bind_cols(df_test2)

df_pred_probs1 %>% roc_auc(Exited, .pred_No)
df_pred_probs2 %>% roc_auc(Exited, .pred_No)

df_pred_probs1 %>% roc_curve(Exited, .pred_No) %>% autoplot()
df_pred_probs2 %>% roc_curve(Exited, .pred_No) %>% autoplot()

vi(ranger_model_1)
vi(ranger_model_2)

vip(ranger_model_1)
vip(ranger_model_2)



pred_dfs <- list(fitted_models, testing_sets, spec_names) %>% pmap_dfc(function(models_by_spec, df, spec_name) {
tibble(!!spec_name :=
models_by_spec %>% map(function(model) {
df %>% bind_cols(
model %>% predict(df),
model %>% predict(df, type = "prob")
)
}))
})

fitted_models %>% write_rds("data/fitted_models.RDS", compress = "bz2")
pred_dfs %>% write_rds("data/predictions.RDS", compress = "bz2")
49 changes: 49 additions & 0 deletions validation_rf.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#### walidation

library(tidyverse)
library(stringr)
library(tidymodels)
library(ranger)
library(vip)
library(knitr)
library(kableExtra)

rm(list = ls())

fitted_models <- read_rds("data/fitted_models.RDS")
pred_dfs <- read_rds("data/predictions.RDS")

source("funs_valid.R")

all_metrics <- get_all_metrics(pred_dfs)
all_metrics %>%
dplyr::select(-2) %>%
kable(format = "html") %>%
save_kable("figures/metrics.png")

pred_dfs[[1]][[1]] %>%
exportable_conf_matrix %>%
kable(format = "html") %>%
save_kable("figures/conf_matrix1.png")

pred_dfs[[2]][[1]] %>%
exportable_conf_matrix %>%
kable(format = "html") %>%
save_kable("figures/conf_matrix2.png")

roc_1 <- pred_dfs[[1]][[1]] %>%
roc_curve(Exited, .pred_No) %>%
autoplot()

roc_2 <- pred_dfs[[2]][[1]] %>%
roc_curve(Exited, .pred_No) %>%
autoplot()

ggsave("figures/roc_1.png", roc_1)
ggsave("figures/roc_2.png", roc_2)

vip_1 <- vip(fitted_models[[1]][[1]])
vip_2 <- vip(fitted_models[[2]][[1]])

ggsave("figures/vip_1.png", vip_1)
ggsave("figures/vip_2.png", vip_2)

0 comments on commit 6fd060d

Please sign in to comment.