-
Notifications
You must be signed in to change notification settings - Fork 0
/
resunet_test_elli.py
66 lines (52 loc) · 2.22 KB
/
resunet_test_elli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from model import resunet
from utility.loss import seg_loss, dice_coef, iou_coef
import os
import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Expose only the CUDA device with index 2 to this process.
# NOTE(review): this runs after `import tensorflow`; presumably it still takes
# effect because no GPU has been initialised yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
def npy_loader(maindir, seed: int = 2):
    """Yield ``(image, mask)`` pairs from every ``.npy`` file in a directory.

    Each file is loaded with ``np.load``. Index 0 of the last axis is used as
    the image and index 1 as the mask. The image is rescaled as
    ``img * 2 - 1`` (presumably mapping [0, 1] to [-1, 1] — confirm the stored
    value range). Both outputs get a trailing singleton axis appended.

    Args:
        maindir: Directory holding the ``.npy`` samples. A trailing path
            separator is optional.
        seed: Currently unused; kept so existing callers that pass it
            continue to work.

    Yields:
        Tuple ``(img, mask)`` of arrays shaped ``(H, W, 1)`` for an input
        array shaped ``(H, W, C)`` with ``C >= 2``.
    """
    # Join rather than concatenate: with plain string concatenation a
    # directory given without a trailing slash would match nothing inside it.
    pattern = os.path.join(maindir, '*.npy')
    # glob returns files in arbitrary order; sort so every run (and every
    # epoch) sees the samples in the same, reproducible order.
    for path in sorted(glob.glob(pattern)):
        array = np.load(path)
        img = array[:, :, 0] * 2 - 1
        mask = array[:, :, 1]
        yield img[..., np.newaxis], mask[..., np.newaxis]
def npy_dataset(maindir, shape_i, shape_m, seed: int = 2, batch: int = 1):
    """Wrap ``npy_loader`` in a batched ``tf.data.Dataset``.

    Args:
        maindir: Directory forwarded to ``npy_loader``.
        shape_i: Declared shape of each yielded image.
        shape_m: Declared shape of each yielded mask.
        seed: Forwarded to ``npy_loader``.
        batch: Number of consecutive samples per batch.

    Returns:
        A dataset of ``(image, mask)`` batches; both elements are declared
        as ``tf.float16``.
    """
    def _samples():
        # from_generator needs a zero-argument callable that returns a fresh
        # generator each time the dataset is iterated.
        return npy_loader(maindir=maindir, seed=seed)

    dataset = tf.data.Dataset.from_generator(
        _samples,
        output_types=(tf.float16, tf.float16),
        output_shapes=(shape_i, shape_m),
    )
    return dataset.batch(batch)
# ---- data pipelines --------------------------------------------------------
# Images and masks share the same (height, width, channels) shape.
batch_size = 4
image_height, image_width = 1024, 832
image_shape = (image_height, image_width, 1)
mask_shape = (image_height, image_width, 1)

train_df = npy_dataset(
    '/home/careinfolab/unet_mammo/images/pos_norm_elli/train/',
    shape_i=image_shape,
    shape_m=mask_shape,
    seed=2,
    batch=batch_size,
)
val_df = npy_dataset(
    '/home/careinfolab/unet_mammo/images/pos_norm_elli/val/',
    shape_i=image_shape,
    shape_m=mask_shape,
    seed=2,
    batch=batch_size,
)
# Built here but not consumed anywhere in this script.
test_df = npy_dataset(
    '/home/careinfolab/unet_mammo/images/pos_norm_elli/test/',
    shape_i=image_shape,
    shape_m=mask_shape,
    seed=2,
    batch=batch_size,
)

# ---- model -----------------------------------------------------------------
model = resunet.get_res_unet(image_height, image_width)
model.compile(
    optimizer=Adam(),
    loss=[seg_loss],
    metrics=[dice_coef, iou_coef],
)

# ---- callbacks -------------------------------------------------------------
# Keep only the checkpoint with the best validation loss, and stop once the
# validation loss has not improved for 6 consecutive epochs.
name = './saved_models/resunet_test1_elli'
model_checkpoint = ModelCheckpoint(name, monitor='val_loss', save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=6)
training_callbacks = [model_checkpoint, early_stopping]

# ---- training --------------------------------------------------------------
# NOTE(review): Keras documents `shuffle` as ignored when the input is a
# tf.data.Dataset, so the training samples are presumably never shuffled
# here — confirm and shuffle inside the dataset pipeline if that is intended.
history = model.fit(
    train_df,
    epochs=150,
    verbose=1,
    shuffle=True,
    validation_data=val_df,
    callbacks=training_callbacks,
)