-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtrain.py
93 lines (82 loc) · 3.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 28 18:32:15 2019
@author: wmy
"""
import os
import random

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from keras import backend as K
from keras.losses import mean_absolute_error, mean_squared_error
from keras.models import load_model
from keras.optimizers import Adam

from model import wdsr_a, wdsr_b
from optimizer import AdamWithWeightsNormalization
from utils import DataLoader
class SuperResolution(object):
    """Train a WDSR-B super-resolution model and dump visual samples.

    Wraps the Keras model built by ``wdsr_b`` together with a ``DataLoader``
    that supplies (low-res, high-res) training batches.

    Args:
        scale: Upscaling factor passed to both the model and the data loader.
        num_res_blocks: Number of residual blocks passed to ``wdsr_b``.
        pretrained_weights: Optional path of a weights file that is loaded
            right after the model is compiled.
        name: Free-form label stored on the instance (not used in this class).
    """

    def __init__(self, scale=4, num_res_blocks=32, pretrained_weights=None, name=None):
        self.scale = scale
        self.num_res_blocks = num_res_blocks
        self.model = wdsr_b(scale=scale, num_res_blocks=num_res_blocks)
        self.model.compile(optimizer=AdamWithWeightsNormalization(lr=0.001),
                           loss=self.mae, metrics=[self.psnr])
        if pretrained_weights is not None:
            self.model.load_weights(pretrained_weights)
            print("[OK] weights loaded.")
        # crop_size=256 is forwarded verbatim; its meaning is defined by DataLoader.
        self.data_loader = DataLoader(scale=scale, crop_size=256)
        self.pretrained_weights = pretrained_weights
        # e.g. 'weights/wdsr-b-32-x4.h5' with the default arguments.
        self.default_weights_save_path = 'weights/wdsr-b-{}-x{}.h5'.format(
            self.num_res_blocks, self.scale)
        self.name = name

    @staticmethod
    def _crop_hr_like(hr, sr):
        """Return ``hr`` centre-cropped to the size of ``sr`` in the training phase.

        The margin is computed from axis 1 only and applied to axes 1 and 2 of
        a 4-D tensor (presumably NHWC with square size differences — confirm).
        Outside the training phase ``hr`` is returned unchanged.
        """
        margin = (tf.shape(hr)[1] - tf.shape(sr)[1]) // 2
        # A zero margin would make hr[:, 0:-0, ...] empty, hence the explicit branch.
        hr_crop = tf.cond(tf.equal(margin, 0),
                          lambda: hr,
                          lambda: hr[:, margin:-margin, margin:-margin, :])
        aligned = K.in_train_phase(hr_crop, hr)
        # Legacy Keras flag marking the tensor as depending on the learning phase.
        aligned.uses_learning_phase = True
        return aligned

    def mae(self, hr, sr):
        """Loss: mean absolute error between the (aligned) target and the prediction."""
        return mean_absolute_error(self._crop_hr_like(hr, sr), sr)

    def psnr(self, hr, sr):
        """Metric: PSNR between the (aligned) target and the prediction, max value 255."""
        return tf.image.psnr(self._crop_hr_like(hr, sr), sr, max_val=255)

    def train(self, epoches=10000, batch_size=8, weights_save_path=None):
        """Run the training loop.

        Args:
            epoches: Number of passes over ``self.data_loader.batches``.
            batch_size: Batch size forwarded to the data loader.
            weights_save_path: File the weights are written to at the end of
                every epoch; defaults to ``self.default_weights_save_path``.
        """
        if weights_save_path is None:
            weights_save_path = self.default_weights_save_path
        # Make sure the target directory exists, otherwise the first
        # end-of-epoch save fails (the default path lives under 'weights/').
        weights_dir = os.path.dirname(weights_save_path)
        if weights_dir:
            os.makedirs(weights_dir, exist_ok=True)
        for epoch in range(epoches):
            batches = self.data_loader.batches(batch_size=batch_size)
            for batch_i, (lrs, hrs) in enumerate(batches):
                temp_loss, temp_psnr = self.model.train_on_batch(lrs, hrs)
                print("[epoch: {}/{}][batch: {}/{}][loss: {}][psnr: {}]".format(
                    epoch + 1, epoches, batch_i + 1, self.data_loader.n_batches,
                    temp_loss, temp_psnr))
                # Write a visual sample every 25 batches.
                if (batch_i + 1) % 25 == 0:
                    self.sample(epoch=epoch + 1, batch=batch_i + 1)
            # Checkpoint once per epoch, overwriting the previous file.
            self.model.save_weights(weights_save_path)
            print("[OK] weights saved.")

    def sample(self, setpath='datasets/train', save_folder='samples', epoch=1, batch=1):
        """Super-resolve one random image from ``setpath`` and save the results.

        Writes ``epoch_<epoch>_batch_<batch>_{lr,sr,hr}.jpg`` into
        ``save_folder`` (created if missing): the downsampled input resized
        back to the original size, the model output, and the original image.
        """
        candidates = self.data_loader.search(setpath)
        hr = self.data_loader.imread(random.choice(candidates))
        lr = self.data_loader.downsampling(hr)
        # Resize the LR input to the HR size so the saved samples share dimensions.
        lr_resize = lr.resize(hr.size)
        # predict() works on batches: wrap the single image, then unwrap the output.
        sr = self.model.predict(np.array([np.asarray(lr)]))[0]
        # Clamp the model output to the 8-bit range before building an image.
        sr = Image.fromarray(np.clip(sr, 0, 255).astype('uint8'))
        os.makedirs(save_folder, exist_ok=True)
        prefix = os.path.join(save_folder, "epoch_{}_batch_{}".format(epoch, batch))
        lr_resize.save(prefix + "_lr.jpg")
        sr.save(prefix + "_sr.jpg")
        hr.save(prefix + "_hr.jpg")
# Entry point: resume training from the pretrained x4 checkpoint.
# Guarded so that importing this module does not start a long training run.
if __name__ == "__main__":
    sr = SuperResolution(pretrained_weights='./weights/wdsr-b-32-x4.h5')
    sr.train()