---
title: "Session 1: Decision Trees in R"
author: "JR Ferrer-Paris"
date: "02/08/2021"
tags: [UNSW coders, Workshop]
abstract: |
  This document walks you through examples of fitting decision trees for classification using the _rpart_ package and two different datasets. It is part of the UNSW codeRs workshop _Introduction to Classification Trees and Random Forests in R_ at https://github.com/UNSW-codeRs/workshop-random-forests
  UNSW codeRs is a student- and staff-run community for ‘R’ users and anyone who wants to further develop their coding skills. Our goal is to create a safe and open space for members to share and gain new experiences relating to R, coding and statistics.
  https://unsw-coders.netlify.app/
output:
  pdf_document:
    template: NULL
editor_options:
  chunk_output_type: console
---
## Overview
Decision trees are recursive partitioning methods that divide the predictor space into simpler regions and can be visualized in a tree-like structure. They classify data by recursively splitting it into subsets that are increasingly homogeneous in the output variable Y, based on the values of the predictors.
## What data do we need?
- Y: The output or response variable is a categorical variable with two or more classes (in R: factor with two or more levels)
- X: A set of predictors or features; these may be a mix of continuous and categorical variables, and they should not have any missing values
### Load data
Here we will work with two examples.
First, we will use the _iris_ dataset from base R. This dataset has 150 observations with four measurements (continuous variables) for three species (a categorical variable with three categories):
```{r dataset1}
data(iris)
str(iris)
```
As a second example we will use the Breast Cancer dataset from the _mlbench_ package. This dataset has 699 observations with nine nominal or ordinal variables describing cell properties; the output or target variable is the class of tumour, with two possible values: benign or malignant:
```{r dataset2}
require(mlbench)
data(BreastCancer)
str(BreastCancer)
```
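Before fitting anything, it is worth verifying that both datasets meet the requirements listed above: the response is a factor and the predictors are (mostly) complete. A minimal check could look like the sketch below; note that _rpart_ can usually cope with a few missing predictor values through surrogate splits.
```{r check_data}
# Is the response a factor? Are there missing values in any column?
is.factor(iris$Species)
colSums(is.na(iris))
is.factor(BreastCancer$Class)
colSums(is.na(BreastCancer))
```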
## What package to use
Classification trees are implemented in packages:
- _tree_: Classification and Regression Trees
- _rpart_: Recursive Partitioning and Regression Trees
### Load packages
Here we will work with the _rpart_ package, and we will also load the _rpart.plot_ package for creating the plots.
```{r load_packages}
library(rpart)
# auxiliary packages for plotting:
library(rpart.plot)
```
## Fit a model
### _iris_ dataset
Let's start with a familiar dataset:
```{r tree_iris}
set.seed(3)
# Fit a classification tree: Species as the response, all other columns as predictors
tree = rpart::rpart(Species ~ ., data = iris,
                    method = "class")
print(tree)
```
This is a very simple tree, and we can walk through the output recognising three levels of nodes: (1) is the root node, and (2) and (3) are the branches based on Petal.Length. Branch (2) has 50 samples, all belonging to the first class (setosa); branch (3) has 100 samples of two different classes. Branch (3) splits into two further branches, (6) and (7), based on Petal.Width; these end-nodes (or leaf-nodes) have 54 and 46 samples, respectively.
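If you find the indented output hard to read, recent versions of _rpart.plot_ also provide `rpart.rules()`, which prints each end-node as a plain decision rule (a quick sketch, assuming the function is available in your installed version):
```{r rules_tree_iris}
# One row per end-node: predicted class and the conditions leading to it
rpart.plot::rpart.rules(tree)
```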
We can visualise the same information in a fancy _rpart.plot_:
```{r plot_tree_iris}
rpart.plot::rpart.plot(tree)
```
This function uses different colors for each category, and splits are labelled with the variable and threshold used. The root node is at the top and the end-nodes are at the bottom; each node is labelled with its modal category and shows the proportion of observations in each category and the percentage of the total sample size.
When an end-node contains samples from only a single class it is considered "pure". So the end-node for _I. setosa_ at the bottom left is pure, while the other end-nodes have 2% and 9% impurity.
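Because the tree was fitted on the full dataset, we can also ask how well it reproduces the observed classes. A simple (and admittedly optimistic, since it uses the training data) check is to cross-tabulate predicted against observed species:
```{r confusion_iris}
# Predicted class for each training observation vs. the observed species
pred_iris <- predict(tree, type = "class")
table(predicted = pred_iris, observed = iris$Species)
```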
### _Breast cancer_ dataset
Now let's look at a more challenging dataset.
```{r tree_breast_cancer}
set.seed(3)
# Drop the Id column: it identifies each sample and is not a predictor
BC.data <- BreastCancer[,-1]
tree = rpart::rpart(Class ~ ., data = BC.data,
                    method = "class")
print(tree)
```
This is a more complex tree, with up to five levels of branching. Can you see them?
The plot is a great visual aid, but what do all these values mean?
```{r plot_tree_breast_cancer}
rpart.plot::rpart.plot(tree)
```
Here the output or response variable has two categories, so the rules are slightly simplified, but it is all pretty similar to the previous example. Each box is labelled with the modal category on top, the proportion of observations in the second class within each group (in this case 'malignant'), and the percentage of total observations within the group. Compare the figure with the printed output to make sense of it.
#### Variable importance
We can also look inside the `tree` object to see its components, for example `variable.importance`:
```{r variable_importance}
names(tree)
data.frame(tree$variable.importance)
```
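These values are easier to compare visually; a quick barplot of the importance scores (a minimal sketch using base graphics) could be:
```{r plot_variable_importance}
# Sort the importance scores and show them as horizontal bars
imp <- sort(tree$variable.importance)
op <- par(mar = c(4, 9, 1, 1))  # widen the left margin to fit variable names
barplot(imp, horiz = TRUE, las = 1, xlab = "Variable importance")
par(op)
```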
#### Complexity parameter
In decision trees the main hyperparameter (configuration setting) is the **complexity parameter** (CP), but the name is a little counterintuitive; a high CP results in a simple decision tree with few splits, whereas a low CP results in a larger decision tree with many splits.
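To see this effect directly, you could refit the model with explicit `cp` values through `rpart.control`; the object names and values below are arbitrary, just to illustrate the extremes:
```{r cp_effect}
# Higher cp -> fewer splits (smaller tree); lower cp -> more splits (larger tree)
tree_simple  <- rpart::rpart(Class ~ ., data = BC.data, method = "class",
                             control = rpart::rpart.control(cp = 0.5))
tree_complex <- rpart::rpart(Class ~ ., data = BC.data, method = "class",
                             control = rpart::rpart.control(cp = 0.001))
# Number of nodes in each tree
nrow(tree_simple$frame)
nrow(tree_complex$frame)
```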
`rpart` uses cross-validation internally to estimate the accuracy at various CP settings. We can review those to see what setting seems best.
Print the results for various CP settings - we want the one with the lowest "xerror".
```{r print_complex_parameter}
printcp(tree)
```
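If you prefer to pick this value programmatically, the cross-validation results are stored in the `cptable` component of the fitted object, so the CP with the lowest xerror can be extracted directly (a minimal sketch):
```{r best_cp}
# Row of the CP table with the smallest cross-validated error
best_cp <- tree$cptable[which.min(tree$cptable[, "xerror"]), "CP"]
best_cp
```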
We can visualise this using the `plotcp` function:
```{r plot_complex_parameter}
plotcp(tree)
```
There is an obvious drop in xerror between one and two splits, but afterwards the differences are pretty small. Considering that a tree with fewer splits might be easier to interpret, we can prune the tree by adjusting the `cp` value:
```{r prune_tree}
# Snip off all splits whose complexity parameter is below 0.037
tree_pruned2 = prune(tree, cp = 0.037)
```
How does it look?
```{r}
rpart.plot(tree_pruned2)
```
So this tree looks much simpler, but is it good enough? Notice that the end-nodes on the left and right have relatively low impurity and together include 96% of the original sample. So with only two variables we can get a good discrimination of most of the samples.
If we want to look at the detailed results, variable importance, and summary of splits we can use:
```{r}
summary(tree_pruned2)
```
And compare this output with the previous `tree` object.
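As a rough check that pruning did not cost us much, we can compare how often each tree reproduces the observed class in the training data (keeping in mind that training accuracy is optimistic; a fairer comparison would use the cross-validated errors from `printcp`):
```{r compare_trees}
# Proportion of training observations classified correctly by each tree
mean(predict(tree, type = "class") == BC.data$Class)
mean(predict(tree_pruned2, type = "class") == BC.data$Class)
```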
That's it for now! Let's move to the next [document](02-hands-on-randmom-forest.Rmd).
## Post-scriptum
#### Additional resources
- Davis David [**Random Forest Classifier Tutorial: How to Use Tree-Based Algorithms for Machine Learning**](https://www.freecodecamp.org/news/how-to-use-the-tree-based-algorithm-for-machine-learning/)
- Evan Muzzall and Chris Kennedy [**Introduction to Machine Learning in R**](https://dlab-berkeley.github.io/Machine-Learning-in-R/slides.html)
- Dave Tang [**Building a classification tree in R**](https://davetang.org/muse/2013/03/12/building-a-classification-tree-in-r/)
- Zach @ Statology [**How to Fit Classification and Regression Trees in R**](https://www.statology.org/classification-and-regression-trees-in-r/)
- Ben Gorman [**Decision Trees in R using rpart**](https://www.gormanalysis.com/blog/decision-trees-in-r-using-rpart/)
#### Session information:
```{r sessionInfo}
sessionInfo()
```