Merge pull request #12 from jcierocki/gbm-model
XGBoost model and training on raw, not factorized datasets
jcierocki authored May 9, 2020
2 parents 2aee0c6 + 9b23293 commit 7fe84f5
Showing 6 changed files with 93 additions and 29 deletions.
Binary file modified: data/split.RDS
Binary file added: data/split_raw.RDS
dataset_prep.R (11 changes: 8 additions & 3 deletions)
@@ -17,10 +17,15 @@ data1 <- data_raw %>%
   HasCrCard = factor(HasCrCard) %>% `levels<-`(c("No", "Yes"))) %>%
   dplyr::select(-RowNumber, -CustomerId, -Surname)
 
-data2 <- data1 %>% factorize() %>% as_tibble() %>% filter_vars_by_iv()
+data1 %>% filter_vars_by_iv(significance_thres = 0.02) %>%
+  initial_split(prop = 0.75) %>%
+  saveRDS("data/split_raw.RDS")
+
+data2 <- data1 %>%
+  factorize(bin_methods = "tree") %>%
+  as_tibble() %>%
+  filter_vars_by_iv(significance_thres = 0.02)
 
 dataset_split <- initial_split(data2, prop = 0.75) %>% saveRDS("data/split.RDS")
 
 rm(list = ls())
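Note: filter_vars_by_iv() is the repo's own helper (defined in funs_preproc.R, not shown in this diff). A hypothetical sketch of such a filter, assuming scorecard::iv() as the information-value source; the body below is an assumption for illustration, not the repo's actual code:

library(dplyr)
library(scorecard)

# Hypothetical stand-in for the repo's filter_vars_by_iv(), illustration only:
# keep the target plus every predictor whose information value clears the threshold.
filter_vars_by_iv_sketch <- function(df, y_name = "Exited", significance_thres = 0.02) {
  ivs <- iv(df, y = y_name)                                  # one row per predictor: variable, info_value
  keep <- ivs$variable[ivs$info_value > significance_thres]  # IV above the cutoff
  df %>% select(all_of(c(y_name, keep)))
}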


funs_preproc.R (24 changes: 15 additions & 9 deletions)
@@ -31,17 +31,23 @@ merge_factor_vars.tbl <- function(var1, ...) {
   do.call(function(...) factor(str_c(...)), as.list(var1))
 }
 
-factorize <- function(df, y_name = "Exited", y_pos = "No") {
+choose_best_binning <- function(binnings_df) {
+  binnings_df %>% pmap(function(...) {
+    opts <- list(...)
+    best_iv_idx <- opts %>% map_dbl(~ .x$total_iv[1]) %>% which.max()
+
+    opts[[best_iv_idx]]
+  })
+}
+
+factorize <- function(df, y_name = "Exited", y_pos = "No", bin_limit = 6, bin_methods = c("tree", "chimerge")) {
   fct_cols <- colnames(df)[df %>% map_lgl(~ !is.factor(.x)) & colnames(df) != y_name]
-  bins_tree <- df %>% woebin(y = y_name, x = fct_cols, positive = y_pos, bin_num_limit = 5, method = "tree")
-  bins_chimerge <- df %>% woebin(y = y_name, x = fct_cols, positive = y_pos, bin_num_limit = 5, method = "chimerge")
+  binnings <- bin_methods %>%
+    map(~ df %>% woebin(y = y_name, x = fct_cols, positive = y_pos, bin_num_limit = bin_limit, method = .x)) %>%
+    `names<-`(bin_methods) %>%
+    as_tibble()
 
-  bins_best <- map2(bins_tree, bins_chimerge, function(x, y) {
-    if(x$total_iv[1] > y$total_iv[1])
-      return(x)
-    else
-      return(y)
-  })
+  bins_best <- choose_best_binning(binnings)
 
   df %>% woebin_ply(bins = bins_best, to = "bin") %>%
     mutate_if(~ !is.factor(.x), as.factor) %>%
(diff truncated here; the remaining lines of factorize() are unchanged)
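The new choose_best_binning() iterates over the rows of the binnings tibble (one row per variable, one list-column per binning method) and keeps, for each variable, the binning whose total_iv is highest, generalizing the old two-method map2() to any number of methods. A toy illustration of that selection rule, with invented total_iv values standing in for woebin() output:

library(purrr)
library(tibble)

# Two variables (rows) binned by two methods (columns); each cell mimics the
# only field the rule inspects in a real woebin() table.
binnings <- tibble(
  tree     = list(data.frame(total_iv = 0.41), data.frame(total_iv = 0.10)),
  chimerge = list(data.frame(total_iv = 0.38), data.frame(total_iv = 0.15))
)

best <- binnings %>% pmap(function(...) {
  opts <- list(...)                                   # one candidate binning per method
  opts[[which.max(map_dbl(opts, ~ .x$total_iv[1]))]]  # keep the highest-IV candidate
})

map_dbl(best, ~ .x$total_iv[1])  # 0.41 (tree wins) and 0.15 (chimerge wins)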
gbm.R (32 changes: 32 additions & 0 deletions)
@@ -11,6 +11,38 @@ rm(list = ls())
 # source("dataset_prep.R")
 
 dataset_split <- readRDS("data/split.RDS")
+dataset_split$data <- dataset_split$data %>%
+  mutate_if(~ length(levels(.x)) > 3, as.integer) %>%
+  mutate_at(vars(Balance), as.integer)
+
+# dataset_split <- readRDS("data/split_raw.RDS")
+
 df_train <- dataset_split %>% training()
 df_test <- dataset_split %>% testing()
 
+gbm_model_1 <- boost_tree(mode = "classification",
+                          mtry = 3,
+                          trees = 500,
+                          min_n = 5,
+                          # tree_depth = 5,
+                          learn_rate = .1,
+                          loss_reduction = 0,
+                          sample_size = 0.7) %>%
+  set_engine("xgboost", objective = "binary:logistic") %>%
+  fit(Exited ~ ., data = df_train)
+
+df_pred <- gbm_model_1 %>%
+  predict(df_test) %>%
+  bind_cols(df_test)
+
+df_pred %>% metrics(Exited, .pred_class)
+
+df_pred_probs <- gbm_model_1 %>%
+  predict(df_test, type = "prob") %>%
+  bind_cols(df_test)
+
+df_pred_probs %>% roc_auc(Exited, .pred_No)
+df_pred_probs %>% roc_curve(Exited, .pred_No) %>% autoplot()
+
+vi(gbm_model_1)
+vip(gbm_model_1)
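Two remarks on gbm.R, based on parsnip's documented behavior rather than anything in this diff: the boost_tree() arguments translate to xgboost's native parameters (trees to nrounds, min_n to min_child_weight, learn_rate to eta, loss_reduction to gamma, sample_size to subsample), and the integer casts at the top exist because xgboost consumes numeric matrices, so recoding the many-level bin factors as integer codes sidesteps a wide one-hot expansion. A small follow-up the script stops short of, sketched here rather than taken from the commit:

# Sketch (not part of this commit): cross-tabulate predicted vs. actual classes
# with yardstick's conf_mat() on the df_pred frame built above.
library(yardstick)
df_pred %>% conf_mat(truth = Exited, estimate = .pred_class)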
rand_forest.R (55 changes: 38 additions & 17 deletions)
@@ -10,31 +10,52 @@ rm(list = ls())

 # source("dataset_prep.R")
 
-dataset_split <- readRDS("data/split.RDS")
-df_train <- dataset_split %>% training()
-df_test <- dataset_split %>% testing()
+dataset_split1 <- readRDS("data/split.RDS")
+dataset_split2 <- readRDS("data/split_raw.RDS")
+
+df_train1 <- dataset_split1 %>% training()
+df_test1 <- dataset_split1 %>% testing()
+df_train2 <- dataset_split2 %>% training()
+df_test2 <- dataset_split2 %>% testing()
 
-ranger_model_1 <- rand_forest("classification", 2, 500, 5) %>%
-  set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "impurity") %>%
-  # set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "permutation", local.importance = T) %>%
-  # set_engine("ranger", num.threads = 8) %>%
-  fit(Exited ~ ., data = df_train)
+ranger_model_specs <- rand_forest("classification", 2, 1000, 5) %>%
+  # set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "impurity") %>%
+  set_engine("ranger", num.threads = 8, replace = F, sample.fraction = 0.8, importance = "permutation", local.importance = T)
+
+ranger_model_1 <- ranger_model_specs %>% fit(Exited ~ ., data = df_train1)
+
+ranger_model_2 <- ranger_model_specs %>% fit(Exited ~ ., data = df_train2)
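The shared spec now requests importance = "permutation" with local.importance = T, which makes ranger record a casewise importance matrix in addition to the global scores. A hedged sketch of pulling it out, assuming ranger's documented variable.importance.local field and parsnip's convention of storing the engine fit in $fit:

# Sketch (not in this commit): the underlying ranger object sits in $fit; with
# local.importance = TRUE it carries a cases-by-predictors importance matrix.
ranger_fit <- ranger_model_1$fit
head(ranger_fit$variable.importance.local)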

-df_pred <- ranger_model_1 %>%
-  predict(df_test) %>%
-  bind_cols(df_test)
+df_pred1 <- ranger_model_1 %>%
+  predict(df_test1) %>%
+  bind_cols(df_test1)
+
+df_pred2 <- ranger_model_2 %>%
+  predict(df_test2) %>%
+  bind_cols(df_test2)
 
-df_pred %>% metrics(Exited, .pred_class)
+df_pred1 %>% metrics(Exited, .pred_class)
+df_pred2 %>% metrics(Exited, .pred_class)
 
-df_pred_probs <- ranger_model_1 %>%
-  predict(df_test, type = "prob") %>%
-  bind_cols(df_test)
+df_pred_probs1 <- ranger_model_1 %>%
+  predict(df_test1, type = "prob") %>%
+  bind_cols(df_test1)
 
-df_pred_probs %>% roc_auc(Exited, .pred_No)
-df_pred_probs %>% roc_curve(Exited, .pred_No) %>% autoplot()
+df_pred_probs2 <- ranger_model_2 %>%
+  predict(df_test2, type = "prob") %>%
+  bind_cols(df_test2)
+
+df_pred_probs1 %>% roc_auc(Exited, .pred_No)
+df_pred_probs2 %>% roc_auc(Exited, .pred_No)
+
+df_pred_probs1 %>% roc_curve(Exited, .pred_No) %>% autoplot()
+df_pred_probs2 %>% roc_curve(Exited, .pred_No) %>% autoplot()
 
 vi(ranger_model_1)
+vi(ranger_model_2)
 
 vip(ranger_model_1)
+vip(ranger_model_2)
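The two roc_auc() calls above print separately; a small sketch (not in the commit) that labels and stacks them, so the binned-versus-raw comparison reads off one table:

# Sketch: tag each roc_auc() result with its dataset and bind the rows, contrasting
# the WoE-binned split (split.RDS) with the raw split (split_raw.RDS).
library(dplyr)
bind_rows(
  df_pred_probs1 %>% roc_auc(Exited, .pred_No) %>% mutate(dataset = "binned"),
  df_pred_probs2 %>% roc_auc(Exited, .pred_No) %>% mutate(dataset = "raw")
)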


