-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
80 lines (58 loc) · 1.89 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import (
Input,
Conv2D,
MaxPooling2D,
UpSampling2D,
concatenate,
Conv2DTranspose,
BatchNormalization,
Dropout,
Lambda,
)
from keras.optimizers import Adam
from keras.layers import Activation, MaxPool2D, Concatenate
def conv_block(input, num_filters):
    """Apply two stacked 3x3 Conv2D -> BatchNormalization -> ReLU stages.

    BatchNormalization is not part of the original U-Net design; it is added
    after each convolution here.

    Args:
        input: Keras tensor to transform.
        num_filters: Number of filters used by both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    out = input
    # Two identical stages, created in the same order as a hand-unrolled version.
    for _stage in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out
# Encoder block: Conv block followed by maxpooling
def encoder_block(input, num_filters):
    """Run conv_block on `input`, then downsample with 2x2 max pooling.

    Args:
        input: Keras tensor to encode.
        num_filters: Number of filters passed to conv_block.

    Returns:
        A ``(features, pooled)`` tuple: ``features`` is the conv_block output
        (kept by the caller for the skip connection) and ``pooled`` is that
        tensor after MaxPool2D with a (2, 2) window.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(features)
    return features, pooled
# Decoder block: upsample, then concatenate with the matching encoder output
# (skip_features) before applying a conv block.
def decoder_block(input, skip_features, num_filters):
    """Upsample `input`, merge it with the encoder's skip tensor, and refine.

    Args:
        input: Keras tensor coming from the deeper level.
        skip_features: Encoder tensor to concatenate with the upsampled input;
            presumably it must match the upsampled spatial size (TODO confirm
            against callers' input shapes).
        num_filters: Filters for the transposed convolution and for conv_block.

    Returns:
        The conv_block output computed on the concatenated tensor.
    """
    # A 2x2 transposed convolution with stride 2 doubles the spatial size.
    upsampled = Conv2DTranspose(
        filters=num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
# Build Unet using the blocks
def build_unet(input_shape, n_classes, *, base_filters=8, depth=4):
    """Build a U-Net model from encoder_block, conv_block and decoder_block.

    With the defaults (``base_filters=8, depth=4``) this creates the same
    layers, in the same order, as the original fixed 8/16/32/64/128 network.

    Args:
        input_shape: Shape passed to ``Input`` (without the batch dimension).
            NOTE(review): the spatial dimensions presumably need to be
            divisible by ``2 ** depth`` so every upsampled tensor lines up with
            its skip tensor — confirm for the shapes actually used.
        n_classes: Number of output channels. ``1`` selects a sigmoid output
            (binary segmentation); any other value selects softmax.
        base_filters: Filters in the first encoder level; each deeper level
            doubles it, and the bridge uses ``base_filters * 2 ** depth``.
        depth: Number of encoder/decoder levels (pooling steps).

    Returns:
        An uncompiled Keras ``Model`` named ``"U-Net"``.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    inputs = Input(input_shape)

    # Contracting path: keep each level's pre-pooling features for the skips.
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = encoder_block(x, base_filters * 2**level)
        skips.append(skip)

    x = conv_block(x, base_filters * 2**depth)  # Bridge

    # Expanding path: consume the skip tensors deepest-first.
    for level in reversed(range(depth)):
        x = decoder_block(x, skips[level], base_filters * 2**level)

    # Single-channel output -> sigmoid (binary); otherwise softmax over classes.
    activation = "sigmoid" if n_classes == 1 else "softmax"
    outputs = Conv2D(n_classes, 1, padding="same", activation=activation)(x)
    return Model(inputs, outputs, name="U-Net")