Skip to content
This repository was archived by the owner on Sep 20, 2024. It is now read-only.

Commit 2e2f7bf

Browse files
committed
Simplified models, config and plot
1 parent 5ec95ea commit 2e2f7bf

13 files changed

+203
-150
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ python src/export.py config/sample.json --checkpoint_path PATH_TO_CHECKPOINT --o
2727
## Models
2828
### Anomaly detection
2929
Anomaly detection will use a autoencoder approche, the prediction substratced from the original image should show the anomaly.
30-
- Fast: Simple 1D prediction
31-
- Advanced: Simple 2D prediction
30+
- Deep Autoencoder: Simple 1D autoencoder
31+
- Convolutional Autoencoder
3232
### Segmentation
3333
#### Unet
3434
- [Vanilla Unet](https://arxiv.org/pdf/1505.04597.pdf) (original paper)

config/sample.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"input_shape": [512, 512, 3],
3-
"model": "fast",
3+
"model": "deep_autoencoder",
44
"train": {
55
"batch_size": 40,
66
"epochs": 100,
@@ -20,7 +20,7 @@
2020
"height_shift_range": 0.3,
2121
"rotation_range": 10
2222
},
23-
"fast_model": {
23+
"deep_autoencoder_model": {
2424
"translator_layer_size": 200,
2525
"middle_layer_size": 32
2626
}

src/image_generator.py

+1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def main():
9696
print("Finished round (" + str(r+1) + "/" + str(rounds) + ")")
9797

9898
def generate_images(config, image_util, output_path, prefix, class_name, images, masks):
99+
prefix = str(prefix) + "_" + class_name
99100
images_output_path = output_path + "/" + class_name
100101
masks_output_path = images_output_path + "/masks"
101102

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .advanced_model import AdvancedModel
2-
from .fast_model import FastModel
1+
from .convolutional_autoencoder_model import ConvolutionalAutoencoderModel
2+
from .deep_autoencoder_model import DeepAutoencoderModel
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,43 @@
1+
import matplotlib.pyplot as plt
12
import numpy as np
3+
import numpy.ma as ma
24
import tensorflow as tf
35
from tensorflow.keras import Input
4-
from tensorflow.keras.layers import (
5-
Activation,
6-
BatchNormalization,
7-
Conv2D,
8-
Conv2DTranspose,
9-
Dense,
10-
Flatten,
11-
LeakyReLU,
12-
Reshape,
13-
)
6+
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
7+
Conv2DTranspose, Dense, Flatten,
8+
LeakyReLU, Reshape)
149
from tensorflow.keras.models import Model
1510
from tensorflow.keras.optimizers import Adam
11+
1612
from models import BaseModel
13+
from utils.plots import *
1714

1815

19-
class AdvancedModel(BaseModel):
16+
class ConvolutionalAutoencoderModel(BaseModel):
2017
def __init__(self, config):
2118
super().__init__(config)
2219

2320
def create_optimizer(self, optimzer="adam"):
2421
super().create_optimizer(optimzer)
2522

26-
def create_model(self, filters=(32, 64), latent_dim=16):
23+
def compile(self, loss="mse"):
24+
self.model.compile(loss=loss, optimizer=self.optimizer, metrics=["accuracy"])
25+
26+
def create_model(self):
27+
filters = (32, 64)
28+
kernel_size = (3,3)
29+
try:
30+
model_config = self.config.train.raw["convolutional_autoencoder_model"]
31+
latent_dim = model_config["latent_dim"]
32+
except:
33+
latent_dim = 16
34+
35+
2736
input_shape = self.config.input_shape
2837
inputs = Input(shape=input_shape, name=self.input_name)
2938
x = inputs
3039
for f in filters:
31-
x = Conv2D(filters=f, kernel_size=(3, 3), strides=2, padding="same")(x)
40+
x = Conv2D(filters=f, kernel_size=kernel_size, strides=2, padding="same")(x)
3241
x = LeakyReLU(alpha=0.2)(x)
3342
x = BatchNormalization(axis=input_shape[2])(x)
3443
volume_size = tf.keras.backend.int_shape(x)
@@ -41,11 +50,11 @@ def create_model(self, filters=(32, 64), latent_dim=16):
4150
x = Reshape((volume_size[1], volume_size[2], volume_size[3]))(x)
4251
for f in filters[::-1]:
4352
x = Conv2DTranspose(
44-
filters=f, kernel_size=(3, 3), strides=2, padding="same"
53+
filters=f, kernel_size=kernel_size, strides=2, padding="same"
4554
)(x)
4655
x = LeakyReLU(alpha=0.2)(x)
4756
x = BatchNormalization(axis=input_shape[2])(x)
48-
x = Conv2DTranspose(filters=input_shape[2], kernel_size=(3, 3), padding="same")(
57+
x = Conv2DTranspose(filters=input_shape[2], kernel_size=kernel_size, padding="same")(
4958
x
5059
)
5160
outputs = Activation("sigmoid", name=self.output_name)(x) # Decoded
@@ -54,9 +63,5 @@ def create_model(self, filters=(32, 64), latent_dim=16):
5463
self.model = Model(inputs, decoder(encoder(inputs)), name="autoencoder")
5564
return self.model
5665

57-
def overwrite_optimizer(self, optimizer, optimizer_name):
58-
self.optimzer = optimizer
59-
self.optimizer_name = optimizer_name
60-
61-
def compile(self, loss="mse"):
62-
self.model.compile(loss=loss, optimizer=self.optimizer, metrics=["accuracy"])
66+
def plot_predictions(self, test_images):
67+
plot_difference(self.config, self.predictions, test_images)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import numpy.ma as ma
4+
import tensorflow as tf
5+
from tensorflow.keras.layers import BatchNormalization, Dense, Flatten, Reshape
6+
from tensorflow.keras.models import Model
7+
from tensorflow.keras.optimizers import Adam
8+
9+
from models import BaseModel
10+
from utils.plots import *
11+
12+
13+
class DeepAutoencoderModel(BaseModel):
14+
def __init__(self, config):
15+
super().__init__(config)
16+
17+
def create_optimizer(self, optimzer="adam"):
18+
super().create_optimizer(optimzer)
19+
20+
def create_model(self):
21+
input_shape = self.config.input_shape
22+
23+
try:
24+
model_config = self.config.train.raw["deep_autoencoder_model"]
25+
translator_layer_size = model_config["translator_layer_size"]
26+
middle_layer_size = model_config["middle_layer_size"]
27+
except:
28+
translator_layer_size = 100
29+
middle_layer_size = 16
30+
31+
sub_layer_size = int(translator_layer_size / 2)
32+
input_dim = input_shape[0] * input_shape[1] * input_shape[2]
33+
input = tf.keras.Input(input_shape, name=self.input_name)
34+
x = input
35+
x = Flatten()(x)
36+
x = BatchNormalization()(x)
37+
x = Dense(translator_layer_size, activation="relu", name="encoder")(x)
38+
x = Dense(sub_layer_size, activation="relu")(x)
39+
x = Dense(middle_layer_size, activation="relu")(x)
40+
x = Dense(sub_layer_size, activation="relu")(x)
41+
x = BatchNormalization()(x)
42+
x = Dense(translator_layer_size, activation="relu", name="decoder")(x)
43+
x = Dense(input_dim, activation="sigmoid", name="reconstructor")(x)
44+
x = Reshape(input_shape, name=self.output_name)(x)
45+
46+
self.model = Model(input, x)
47+
return self.model
48+
49+
def compile(self, loss="mean_squared_error"):
50+
self.model.compile(loss=loss, optimizer=self.optimizer, metrics=["accuracy"])
51+
52+
def plot_predictions(self, test_images):
53+
plot_difference(self.config, self.predictions, test_images)

src/models/anomaly_detection/fast_model.py

-86
This file was deleted.

src/models/base_model.py

+4-32
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tensorflow.keras.preprocessing.image import ImageDataGenerator
1111

1212
from utils import *
13+
from utils.plots import *
1314

1415

1516
class BaseModel:
@@ -84,7 +85,7 @@ def load_weights(self):
8485

8586
def train(self):
8687
if not self.train_datagen == None:
87-
self.model.fit(
88+
self.history = self.model.fit(
8889
self.train_datagen,
8990
epochs=self.config.train.epochs,
9091
steps_per_epoch=len(self.train_images) / self.config.train.batch_size,
@@ -95,7 +96,7 @@ def train(self):
9596
)
9697

9798
else:
98-
self.model.fit(
99+
self.history = self.model.fit(
99100
self.train_images,
100101
self.y_train,
101102
batch_size=self.config.train.batch_size,
@@ -114,36 +115,7 @@ def predict(self, test_images):
114115
)
115116

116117
def plot_predictions(self, test_images):
117-
pred_count = len(self.predictions)
118-
plt_shape = (self.config.input_shape[0], self.config.input_shape[1])
119-
plt_cmap = "gray"
120-
if self.config.input_shape[2] > 1:
121-
plt_shape = (
122-
self.config.input_shape[0],
123-
self.config.input_shape[1],
124-
self.config.input_shape[2],
125-
)
126-
index = 1
127-
plt_index = 0
128-
for test_image in test_images:
129-
original_image = test_image.reshape(plt_shape)
130-
pred_image = self.predictions[plt_index].reshape(plt_shape)
131-
mask = ma.masked_where(pred_image < self.config.test_threshold, pred_image)
132-
plt.subplot(pred_count, 3, index)
133-
plt.title("Original")
134-
plt.imshow(original_image, interpolation="none", cmap=plt_cmap)
135-
index += 1
136-
plt.subplot(pred_count, 3, index)
137-
plt.title("Prediction")
138-
plt.imshow(pred_image, interpolation="none", cmap=plt_cmap)
139-
index += 1
140-
plt.subplot(pred_count, 3, index)
141-
plt.title("Overlay")
142-
plt.imshow(original_image, interpolation="none", cmap=plt_cmap)
143-
plt.imshow(mask, cmap="jet", interpolation="none", alpha=0.7)
144-
index += 1
145-
plt_index += 1
146-
plt.show()
118+
utils.plot_prediction(config, self.predictions, test_images)
147119

148120
def prepare_training(self):
149121
self.train_images = None

src/train.py

+30
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,18 @@ def main():
8282
metavar="path",
8383
help="Overwrites the path to the saved checkpoint containing the model weights",
8484
)
85+
parser.add_argument(
86+
"--plot_history",
87+
dest="plot_history",
88+
metavar="boolean (default: false)",
89+
type=bool,
90+
help="Plots the model training history",
91+
)
8592

8693
args = parser.parse_args()
8794
config_path = args.config
8895
config = Config(config_path)
96+
plot_history = False
8997

9098
# Overwrite config
9199
if args.train_files_path:
@@ -109,6 +117,9 @@ def main():
109117
if args.checkpoint_path:
110118
config.train.checkpoint_path = args.checkpoint_path
111119

120+
if args.plot_history:
121+
plot_history = True
122+
112123
# Set seed to get reproducable experiments
113124
seed_value = 33
114125
os.environ["PYTHONHASHSEED"] = str(seed_value)
@@ -128,7 +139,26 @@ def main():
128139

129140
# ToDo: Train model
130141
model.train()
142+
history = model.history
143+
epochs = len(history.epoch) + model.initial_epoch
144+
model.model.save_weights(config.train.checkpoints_path + "/model-{0:04d}.ckpts".format(epochs))
131145

146+
if plot_history:
147+
plt.subplot(121)
148+
plt.plot(history.history['accuracy'])
149+
plt.plot(history.history['val_accuracy'])
150+
plt.title('Model accuracy')
151+
plt.ylabel('Accuracy')
152+
plt.xlabel('Epoch')
153+
plt.legend(['Train', 'Test'], loc='upper left')
154+
plt.subplot(122)
155+
plt.plot(history.history['loss'])
156+
plt.plot(history.history['val_loss'])
157+
plt.title('Model loss')
158+
plt.ylabel('Loss')
159+
plt.xlabel('Epoch')
160+
plt.legend(['Train', 'Test'], loc='upper left')
161+
plt.show()
132162

133163
if __name__ == "__main__":
134164
main()

src/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .config import Config
22
from .image_util import ImageUtil
33
from .model_creater import *
4+
from .plots import *

0 commit comments

Comments
 (0)