```{r 10_setup, include=FALSE}
knitr::opts_chunk$set(echo=TRUE, eval=FALSE)
```
# Evaluating Classification Models
## Learning Goals {-}
- Calculate (by hand from confusion matrices) and contextually interpret overall accuracy, sensitivity, and specificity
- Construct and interpret plots of predicted probabilities across classes
- Explain how a ROC curve is constructed and the rationale behind AUC as an evaluation metric
- Appropriately use and interpret the no-information rate to evaluate accuracy metrics
<br>
Slides from today are available [here](https://docs.google.com/presentation/d/1Ssof1lK7yOwh3b0Xuu2FTw6Rmw0rccFrRsDXwV56ZFo/edit?usp=sharing).
<br><br><br>
## LASSO for logistic regression in `caret` {-}
To build LASSO models for logistic regression in `caret`, first load the package and set the seed for the random number generator to ensure reproducible results:
```{r}
library(caret)
set.seed(___) # Pick your favorite number to fill in the parentheses
```
If we would like to use estimates of test (overall) accuracy to choose our final model (based on a probability threshold of 0.5), we can adapt the following:
```{r}
lasso_logistic_mod <- train(
    y ~ x,
    data = ___,
    method = "glmnet",
    family = "binomial",
    tuneGrid = data.frame(alpha = 1, lambda = seq(___, ___, length.out = ___)),
    trControl = trainControl(method = "cv", number = ___, selectionFunction = ___),
    metric = "Accuracy",
    na.action = na.omit
)
```
Argument | Meaning
---------|--------
`y ~ x` | `y` must be a `character` or `factor`
`data` | Sample data
`method` & `family` | The `glmnet` method fits "generalized" linear models with regularization. When we specify `family = "binomial"`, it performs (penalized) logistic regression.
`trControl` | Use cross-validation to estimate test performance for each model fit. `selectionFunction` can be `"best"` or `"oneSE"` as before.
`tuneGrid` | Tuning parameters for LASSO: `alpha = 1` indicates LASSO (as opposed to another regularization method). Specify a sequence of `lambda` values for the penalty.
`metric` | Evaluate and compare competing models with respect to their CV-`Accuracy`. (Uses a default probability threshold of 0.5 to make hard predictions.)
`na.action` | Set `na.action = na.omit` to prevent errors if the data has missing values.
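To make the template concrete, here is a minimal hypothetical filled-in version. The data frame `my_data`, its two-level factor outcome `y`, and the particular seed, `lambda` grid, and fold count are made-up placeholders for illustration, not recommended values:
```{r}
# Hypothetical filled-in version of the template above (illustration only):
# assumes a data frame `my_data` whose outcome `y` is a 2-level factor
set.seed(253)

lasso_logistic_example <- train(
    y ~ .,    # use all other columns as candidate predictors
    data = my_data,
    method = "glmnet",
    family = "binomial",
    tuneGrid = data.frame(alpha = 1, lambda = seq(0, 1, length.out = 100)),
    trControl = trainControl(method = "cv", number = 10, selectionFunction = "best"),
    metric = "Accuracy",
    na.action = na.omit
)
```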
<br>
If we would like to choose our final model based on estimates of test sensitivity, specificity, or AUC, we can adapt the following:
```{r}
lasso_logistic_mod <- train(
    y ~ x,
    data = ___,
    method = "glmnet",
    family = "binomial",
    tuneGrid = data.frame(alpha = 1, lambda = seq(___, ___, length.out = ___)),
    trControl = trainControl(method = "cv", number = ___, selectionFunction = ___,
        classProbs = TRUE, summaryFunction = twoClassSummaryCustom),
    metric = "AUC",
    na.action = na.omit
)
```
The two new arguments to `trainControl()` are as follows:
- `classProbs`: Set to `TRUE` so that soft (probability) predictions are computed
- `summaryFunction`: Use `twoClassSummaryCustom` to compute overall accuracy, sensitivity, and specificity (based on a threshold of 0.5) as well as AUC
The `metric` argument now has 4 options, specifying which measure `selectionFunction` uses to choose the final model:
- `"AUC"`: Select based on CV AUC
- `"Sens"`: Select based on CV sensitivity
- `"Spec"`: Select based on CV specificity
- `"Accuracy"`: Select based on CV overall accuracy
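For concreteness, here is a corresponding hypothetical filled-in version of the AUC-based template (again using the made-up `my_data` and outcome `y` from the sketch above, with illustrative tuning choices). Note that it requires the `twoClassSummaryCustom` function, defined just below, to have been run first:
```{r}
# Hypothetical filled-in version that selects lambda by CV AUC (illustration only);
# assumes my_data / y as before and that twoClassSummaryCustom has been created
set.seed(253)

lasso_logistic_auc_example <- train(
    y ~ .,
    data = my_data,
    method = "glmnet",
    family = "binomial",
    tuneGrid = data.frame(alpha = 1, lambda = seq(0, 1, length.out = 100)),
    trControl = trainControl(method = "cv", number = 10, selectionFunction = "oneSE",
        classProbs = TRUE, summaryFunction = twoClassSummaryCustom),
    metric = "AUC",    # could instead be "Sens", "Spec", or "Accuracy"
    na.action = na.omit
)
```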
You'll need to run the code below to create the `twoClassSummaryCustom` function (don't worry about how this is written):
```{r}
twoClassSummaryCustom <- function(data, lev = NULL, model = NULL) {
    # Only defined for binary outcomes
    if (length(lev) > 2) {
        stop(paste("Your outcome has", length(lev), "levels. The twoClassSummary() function isn't appropriate."))
    }
    caret:::requireNamespaceQuietStop("pROC")
    if (!all(levels(data[, "pred"]) == lev)) {
        stop("levels of observed and predicted data do not match")
    }
    # AUC from the ROC curve, based on the predicted probabilities of the first class level
    rocObject <- try(pROC::roc(data$obs, data[, lev[1]], direction = ">",
        quiet = TRUE), silent = TRUE)
    rocAUC <- if (inherits(rocObject, "try-error"))
        NA
    else rocObject$auc
    # Sensitivity and specificity from the hard class predictions (0.5 threshold)
    out <- c(rocAUC,
             sensitivity(data[, "pred"], data[, "obs"], lev[1]),
             specificity(data[, "pred"], data[, "obs"], lev[2]))
    # Overall accuracy from the hard class predictions
    out2 <- postResample(data[, "pred"], data[, "obs"])
    out <- c(out, out2[1])
    names(out) <- c("AUC", "Sens", "Spec", "Accuracy")
    out
}
```
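To see what this function computes, here is a tiny made-up illustration. Within cross-validation, `caret` calls the summary function on each fold's held-out cases, passing a data frame with the observed classes (`obs`), the hard predicted classes (`pred`), and one column of predicted probabilities per class level. The `fake_preds` data frame below is a hypothetical stand-in for that input (it also requires the `pROC` package, installed in the exercises below):
```{r}
# Hypothetical held-out predictions for 4 emails, mimicking what caret passes
# to the summary function during CV (purely made-up numbers for illustration)
fake_preds <- data.frame(
    obs  = factor(c("spam", "spam", "not_spam", "not_spam"), levels = c("not_spam", "spam")),
    pred = factor(c("spam", "not_spam", "not_spam", "not_spam"), levels = c("not_spam", "spam")),
    not_spam = c(0.2, 0.6, 0.7, 0.9),  # predicted probability of not_spam
    spam     = c(0.8, 0.4, 0.3, 0.1)   # predicted probability of spam
)

# Returns a named vector with the AUC, Sens, Spec, and Accuracy for these predictions
twoClassSummaryCustom(fake_preds, lev = c("not_spam", "spam"))
```
Each row of `lasso_logistic_mod$resample` (see Exercise 5) comes from exactly this kind of computation applied to one fold's held-out predictions.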
<br><br><br>
## Exercises {-}
**You can download a template RMarkdown file to start from [here](template_rmds/10-evaluating-classification.Rmd).**
### Context {-}
Before proceeding, install the `pROC` package (utilities for evaluating classification models with ROC curves) by entering `install.packages("pROC")` in the Console.
We'll continue working with the [spam dataset](https://archive.ics.uci.edu/ml/datasets/Spambase) from last time.
- `spam`: Either `spam` or `not spam` (outcome)
- `word_freq_WORD`: percentage of words in the e-mail that match `WORD` (0-100)
- `char_freq_CHAR`: percentage of characters in the e-mail that match `CHAR` (e.g., exclamation points, dollar signs)
- `capital_run_length_average`: average length of uninterrupted sequences of capital letters
- `capital_run_length_longest`: length of longest uninterrupted sequence of capital letters
- `capital_run_length_total`: sum of length of uninterrupted sequences of capital letters
Our goal will be to use email features to predict whether or not an email is spam - essentially, to build a spam filter!
```{r}
library(readr)
library(ggplot2)
library(dplyr)
library(caret)
spam <- read_csv("https://www.dropbox.com/s/leurr6a30f4l32a/spambase.csv?dl=1")
# A little data cleaning to remove the space in "not spam"
spam <- spam %>%
    mutate(spam = ifelse(spam == "spam", "spam", "not_spam"))
```
### Exercise 1: Conceptual warmup {-}
LASSO for the logistic regression setting works analogously to the regression setting. How would you expect a plot of test **accuracy** vs. $\lambda$ to look, and why? (Draw it!)
### Exercise 2: Implementing LASSO logistic regression in `caret` {-}
Fit a LASSO logistic regression model for the `spam` outcome, and allow all possible predictors to be considered (`~ .` in the model formula).
- Use 10-fold CV.
- Choose a final model whose test AUC is within one standard error of the overall best metric.
- Initially try a sequence of 100 $\lambda$'s from 0 to 10.
- Diagnose whether this sequence should be updated by looking at the plot of test AUC vs. $\lambda$ (`plot(lasso_logistic_mod)`).
- If needed, adjust the max value in the sequence up or down by a factor of 10. (You'll be able to determine from the plot whether to adjust up or down.)
```{r}
set.seed(___)
lasso_logistic_mod <- train(
)
```
### Exercise 3: Inspecting the model {-}
Inspect the `$bestTune` part of your fitted `lasso_logistic_mod` in conjunction with the plot of test AUC vs. $\lambda$.
Is anything surprising about the results relative to your expectations from Exercise 1? Brainstorm some possible explanations in consideration of the data context.
### Exercise 4: Interpreting evaluation metrics {-}
Inspect the overall CV results for the "best" $\lambda$, and compute the no-information rate (NIR):
```{r}
# CV results for "best lambda"
lasso_logistic_mod$results %>%
    filter(lambda == lasso_logistic_mod$bestTune$lambda)

# Count up number of spam and not_spam emails in the training data
spam %>%
    count(spam) # Name of the outcome variable goes inside count()
# Compute the NIR
```
- Interpret the estimates of test sensitivity and specificity - what do these numbers mean? Do you think higher sensitivity or specificity would be more important in designing a spam filter?
- Interpret overall accuracy - does this seem high? How can the no-information rate (NIR) help us interpret the overall accuracy?
- Why is an AUC of 1 the best possible value for this metric? How does the AUC for our spam model look relative to this best value?
### Exercise 5: Algorithmic understanding for evaluation metrics {-}
Inspect the iteration-specific information from CV for the "best" $\lambda$:
```{r}
lasso_logistic_mod$resample
```
How is one row of information computed? Carefully describe the CV process for a single iteration to estimate each of `AUC`, `Sens`, `Spec`, and `Accuracy` (overall accuracy). Use a generic confusion matrix (filled with variables instead of hard numbers) to illustrate the underlying computations.