forked from LeeJunHyun/Image_Segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
113 lines (94 loc) · 4.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
from torch.utils.data import DataLoader
import os
from solver import Solver
# from data_loader import get_loader
from torch.backends import cudnn
import random
from fudan_2d_loader import Fudan_2D_Dataset
def main(config):
    """Validate the run configuration, build the data loaders and run the Solver.

    Args:
        config: argparse.Namespace produced by the ``__main__`` parser. This
            function mutates it: ``result_path`` gets ``model_type`` appended,
            and ``augmentation_prob``, ``num_epochs``, ``lr`` and
            ``num_epochs_decay`` are overwritten (see the NOTE below).

    Returns:
        None. If ``model_type`` or ``mode`` is unsupported, an error is printed
        and the function returns before creating directories or loading data.
    """
    cudnn.benchmark = True

    if config.model_type not in ['U_Net', 'R2U_Net', 'AttU_Net', 'R2AttU_Net']:
        print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
        print('Your input for model_type was %s' % config.model_type)
        return
    # Reject unknown modes up front. Previously a typo such as "Train" built the
    # datasets, loaders and Solver and then silently did nothing.
    if config.mode not in ('train', 'test'):
        print('ERROR!! mode should be selected in train/test')
        print('Your input for mode was %s' % config.mode)
        return

    # Create output directories if they do not exist yet. exist_ok avoids the
    # race of a separate existence check; makedirs also creates the parent
    # result directory before the per-model subdirectory.
    os.makedirs(config.model_path, exist_ok=True)
    config.result_path = os.path.join(config.result_path, config.model_type)
    os.makedirs(config.result_path, exist_ok=True)

    # NOTE(review): these assignments overwrite whatever was passed on the
    # command line for --augmentation_prob, --num_epochs, --lr and
    # --num_epochs_decay. augmentation_prob is drawn from [0, 0.7) on every
    # run, so runs are not reproducible unless `random` is seeded first.
    config.augmentation_prob = random.random() * 0.7
    config.num_epochs = 150
    config.lr = 0.00005
    config.num_epochs_decay = 50
    print(config)

    train_file = './data/train.csv'
    val_file = './data/val.csv'
    # NOTE(review): the test loader reads the validation CSV; confirm this is intended.
    test_file = './data/val.csv'

    # Honour --num_workers instead of a hard-coded 8 (the flag's default is 8).
    workers = config.num_workers

    train_data = Fudan_2D_Dataset(csv_file=train_file, phase='train', flip_rate=0)
    train_loader = DataLoader(train_data, batch_size=config.batch_size,
                              shuffle=True, num_workers=workers)

    valid_data = Fudan_2D_Dataset(csv_file=val_file, phase='val', flip_rate=0)
    valid_loader = DataLoader(valid_data, batch_size=1, num_workers=workers)

    test_data = Fudan_2D_Dataset(csv_file=test_file, phase='val', flip_rate=0)
    test_loader = DataLoader(test_data, batch_size=1, num_workers=workers)

    solver = Solver(config, train_loader, valid_loader, test_loader)

    # Train or evaluate; mode has already been validated above.
    if config.mode == 'train':
        solver.train()
    else:
        solver.test()
if __name__ == '__main__':
    # Command-line entry point: parse the run configuration and hand it to main().
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net')

    # training hyper-parameters
    parser.add_argument('--img_ch', type=int, default=3)
    parser.add_argument('--output_ch', type=int, default=1)
    # NOTE(review): main() currently overwrites num_epochs, num_epochs_decay,
    # lr and augmentation_prob with its own values, so these flags have no effect.
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--num_epochs_decay', type=int, default=70)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)  # momentum1 in Adam
    parser.add_argument('--beta2', type=float, default=0.999)  # momentum2 in Adam
    parser.add_argument('--augmentation_prob', type=float, default=0.4)
    parser.add_argument('--log_step', type=int, default=2)
    parser.add_argument('--val_step', type=int, default=5)

    # misc
    # `choices` makes argparse reject typos with a usage error instead of
    # letting an invalid value reach main().
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--model_type', type=str, default='AttU_Net',
                        choices=['U_Net', 'R2U_Net', 'AttU_Net', 'R2AttU_Net'],
                        help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
    parser.add_argument('--model_path', type=str, default='/mnt/HDD2/mingjian/results/Atten_unet/model/')
    # NOTE(review): main() does not read the three *_path flags below (its loaders
    # use ./data/*.csv); Solver may still use them. Verify before removing.
    parser.add_argument('--train_path', type=str, default='./dataset/train/')
    parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
    parser.add_argument('--test_path', type=str, default='./dataset/test/')
    parser.add_argument('--result_path', type=str, default='/mnt/HDD2/mingjian/results/Atten_unet/result/')
    # presumably consumed by Solver to select the GPU; TODO confirm
    parser.add_argument('--cuda_idx', type=int, default=1)

    config = parser.parse_args()
    main(config)