generated from KSUDS/p4_machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.R
88 lines (68 loc) · 2.28 KB
/
model.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
library(tidyverse)
library(tidymodels)
library(DALEX)
library(skimr)
library(GGally)
library(xgboost)
library(vip)
library(patchwork)
# Graphics device ----
# Start the httpgd graphics device and open its viewer.
httpgd::hgd()
httpgd::hgd_browse()

# Data ----
# Read the modelling data set from the working directory.
dat_ml <- read_rds("dat_ml.rds")

# Seed the RNG before splitting so the partition can be reproduced.
set.seed(76)

# Half of the rows go to training and half to testing, stratified on the
# outcome column `before1980`.
dat_split <- dat_ml %>%
  initial_split(prop = 1 / 2, strata = before1980)
dat_train <- dat_split %>% training()
dat_test <- dat_split %>% testing()
# Model fitting ----
# Three classifiers are trained on the same training set with the same
# formula: `before1980` is the outcome and every other column is a predictor.
# Each model is written as a specification first and fitted second.

# Boosted trees on the xgboost engine.
bt_spec <- set_mode(
  set_engine(boost_tree(), engine = "xgboost"),
  "classification"
)
bt_model <- fit(bt_spec, before1980 ~ ., data = dat_train)

# Logistic regression on the glm engine.
logistic_spec <- set_mode(
  set_engine(logistic_reg(), engine = "glm"),
  "classification"
)
logistic_model <- fit(logistic_spec, before1980 ~ ., data = dat_train)

# Naive Bayes on the naivebayes engine; the model function is taken from the
# discrim namespace, which is not attached with `library()`.
nb_spec <- set_mode(
  set_engine(discrim::naive_Bayes(), engine = "naivebayes"),
  "classification"
)
nb_model <- fit(nb_spec, before1980 ~ ., data = dat_train)
# Variable importance ----
# Importance plots (up to 20 features each) for the boosted tree and the
# logistic regression, combined into one figure with `+`.
vip_bt <- vip(bt_model, num_features = 20)
vip_logistic <- vip(logistic_model, num_features = 20)
vip_bt + vip_logistic
# Test-set predictions ----

# Build the prediction table for one fitted model.
#
# model:     a fitted model accepted by `predict()`
# new_data:  data frame to predict on; it must contain the outcome column
# truth_col: name (string) of the outcome column in `new_data`
#
# Returns the columns from `predict()` with its default type, then the
# columns from `predict(type = "prob")`, then a `truth` column holding the
# observed outcome taken from `new_data`.
predict_with_truth <- function(model, new_data, truth_col = "before1980") {
  bind_cols(
    predict(model, new_data = new_data),
    predict(model, new_data = new_data, type = "prob"),
    truth = new_data[[truth_col]]
  )
}

preds_logistic <- predict_with_truth(logistic_model, dat_test)

# takes a minute
preds_nb <- predict_with_truth(nb_model, dat_test)

preds_bt <- predict_with_truth(bt_model, dat_test)
# Classification metrics ----
# Bundle the class-prediction metrics into a single callable so they can be
# computed together on any prediction table.
metrics_calc <- metric_set(
  accuracy,
  bal_accuracy,
  precision,
  recall,
  f_meas
)

# Score the boosted-tree predictions against the observed outcome.
metrics_calc(preds_bt, truth, estimate = .pred_class)
# ROC curve: boosted trees ----
# `roc_curve()` receives the class-probability column through `...`; it has
# no `estimate` argument, and recent yardstick releases raise an error when
# the column is supplied as `estimate =`, so it is passed unnamed here.
# NOTE(review): yardstick treats the first factor level of `truth` as the
# event by default; confirm that level is the one `.pred_before` refers to,
# otherwise set `event_level = "second"`.
roc_plot_bt <- preds_bt %>%
  roc_curve(truth, .pred_before) %>%
  autoplot() +
  labs(title = "ROC Curve - Boosted Trees") +
  theme_bw()
roc_plot_bt

# Save the plot object explicitly instead of relying on `last_plot()`.
ggsave("ROCCurveBT.png", plot = roc_plot_bt)
# ROC curve: naive Bayes ----
# `roc_curve()` receives the class-probability column through `...`; it has
# no `estimate` argument, and recent yardstick releases raise an error when
# the column is supplied as `estimate =`, so it is passed unnamed here.
# NOTE(review): yardstick treats the first factor level of `truth` as the
# event by default; confirm that level is the one `.pred_before` refers to,
# otherwise set `event_level = "second"`.
roc_plot_nb <- preds_nb %>%
  roc_curve(truth, .pred_before) %>%
  autoplot() +
  labs(title = "ROC Curve - Naive Bayes") +
  theme_bw()
roc_plot_nb

# Save the plot object explicitly instead of relying on `last_plot()`.
ggsave("ROCCurveNB.png", plot = roc_plot_nb)
# Model comparison ----
# Stack the three prediction tables, tagging each row with the model that
# produced it.
preds_all <- bind_rows(
  preds_nb %>% mutate(model = "Naive Bayes"),
  preds_bt %>% mutate(model = "Boosted Tree"),
  preds_logistic %>% mutate(model = "Logistic Regression")
)

# Compute the metric set separately for each model, then spread the metric
# names into columns for a side-by-side comparison.
metrics_long <- preds_all %>%
  group_by(model) %>%
  metrics_calc(truth, estimate = .pred_class)
pivot_wider(metrics_long, names_from = .metric, values_from = .estimate)