---
title: "Organic ML Pretrained Convnet"
output:
  md_document:
    variant: markdown_github
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
## Load the libraries
```{r}
library(keras)
library(tensorflow)
```
## Define where the image datasets are
There is a small dataset of 410 images. Because there are so few images, I will just work with a training and validation set, and I will not use a separate test dataset.
|dataset|benzene ring images|non-benzene ring images|
|---|---|---|
|train|152|148|
|validation|57|53|
```{r}
train_dir <- "data/train"
validation_dir <- "data/validation"
```
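The generators created in the next section infer the two classes from the subdirectory names inside each of these folders. As a rough sketch (the class folder names shown here are hypothetical; the actual names in this project may differ), the expected layout is one subdirectory per class:
```{r eval=FALSE}
# Hypothetical layout; flow_images_from_directory() treats each subdirectory
# of train_dir / validation_dir as one class:
#   data/train/benzene/*.png
#   data/train/no_benzene/*.png
#   data/validation/benzene/*.png
#   data/validation/no_benzene/*.png
list.dirs("data", recursive = TRUE)
```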
## Create the image generators
Both the training and validation generators simply rescale the pixel values to the [0, 1] range; no data augmentation is applied here.
Create a function that builds the generators. This function will be called twice: the first time when training the first network, and the second time when training the second network.
```{r warning=FALSE}
create_generators <- function(test_validation_batch_size) {
  validation_datagen <- image_data_generator(rescale = 1/255)
  train_datagen <- image_data_generator(rescale = 1/255)
  train_generator <- flow_images_from_directory(
    train_dir,
    train_datagen,
    target_size = c(256, 256),
    batch_size = test_validation_batch_size,
    shuffle = FALSE,
    class_mode = "binary"
  )
  validation_generator <- flow_images_from_directory(
    validation_dir,
    validation_datagen,
    target_size = c(256, 256),
    batch_size = test_validation_batch_size,
    shuffle = FALSE,
    class_mode = "binary"
  )
  list(train_generator = train_generator, validation_generator = validation_generator)
}
```
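As a quick sanity check (a minimal sketch; the batch size of 10 and the name `gens` are arbitrary), the returned generators report how many images they found and which 0/1 label each class folder was mapped to:
```{r eval=FALSE}
gens <- create_generators(test_validation_batch_size = 10)
gens$train_generator$samples        # expected: 300 training images
gens$validation_generator$samples   # expected: 110 validation images
gens$train_generator$class_indices  # class folder name -> 0/1 label mapping
```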
## Define the network
Use VGG16 trained on imagenet for the convolutional layers. Freeze the convolutional layers so that their weights are not adjusted during training. The final layers will be a simple binary classifier made with dense layers.
First, make a function that returns an initialized, compiled network.
```{r}
create_network <- function() {
  conv_base <- application_vgg16(
    weights = "imagenet",
    include_top = FALSE,
    input_shape = c(256, 256, 3)
  )
  conv_base %>% freeze_weights()
  network <- keras_model_sequential() %>%
    conv_base %>%
    layer_flatten() %>%
    layer_dense(units = 256, activation = "relu") %>%
    layer_dense(units = 128, activation = "relu") %>%
    layer_dense(units = 64, activation = "relu") %>%
    layer_dense(units = 1, activation = "sigmoid")
  network %>% compile(
    loss = "binary_crossentropy",
    optimizer = optimizer_rmsprop(learning_rate = 0.5e-5),
    metrics = c("accuracy")
  )
  network
}
```
Now use the function to create the network and inspect its architecture.
```{r error=FALSE, warning=FALSE}
convnet_1 <- create_network()
summary(convnet_1)
```
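As a quick check that freezing the VGG16 base took effect (a minimal sketch), only the dense classifier head should contribute trainable weight tensors:
```{r eval=FALSE}
# Frozen VGG16 weights appear under non_trainable_weights;
# only the dense head's weights and biases should be trainable.
length(convnet_1$trainable_weights)
length(convnet_1$non_trainable_weights)
```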
## Fit the network using the training and validation data
Checkpoint the model each time `val_accuracy` improves on the best value so far. Stop training when there have been 5 epochs without improvement in `val_accuracy`.
```{r error=FALSE, results='hide'}
tensorflow::set_random_seed(0)
callbacks_list <- list(
  # Save the model each time val_accuracy improves on the best value so far
  callback_model_checkpoint(
    filepath = "organicml_checkpoint.h5",
    monitor = "val_accuracy",
    mode = "max",
    save_best_only = TRUE
  ),
  # Stop after 5 epochs with no improvement in val_accuracy
  callback_early_stopping(
    monitor = "val_accuracy",
    mode = "max",
    patience = 5
  )
)
test_validation_batch_size <- 25
generators_1 <- create_generators(test_validation_batch_size = test_validation_batch_size)
fit_history_1 <- convnet_1 %>%
  fit(
    generators_1$train_generator,
    steps_per_epoch = 300 / test_validation_batch_size,
    epochs = 50,
    validation_data = generators_1$validation_generator,
    validation_steps = 110 / test_validation_batch_size,
    callbacks = callbacks_list
  )
```
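Because the checkpoint callback saves only the best-scoring model, the weights from the best epoch can be restored afterwards if needed (a minimal sketch; `best_convnet_1` is an arbitrary name, and the file is the checkpoint written above):
```{r eval=FALSE}
# Reload the best checkpointed model and re-evaluate it on the validation set
best_convnet_1 <- load_model_hdf5("organicml_checkpoint.h5")
best_convnet_1 %>% evaluate(generators_1$validation_generator)
```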
## Analyze training history
```{r}
plot(fit_history_1)
```
What epoch had the maximum validation accuracy?
```{r}
max_val_accuracy_1 <- max(fit_history_1$metrics$val_accuracy)
argmax_val_accuracy_1 <- which.max(fit_history_1$metrics$val_accuracy)
cat("Maximum val_accuracy: ", max_val_accuracy_1, "\n")
cat("Epoch of maximum validation accuracy: ", argmax_val_accuracy_1, "\n")
```