Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
cszn authored Nov 19, 2018
1 parent 5222842 commit 6f286d5
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions TrainingCodes/dncnn_pytorch/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def show(x, title=None, cbar=False, figsize=None):


def data_aug(img, mode=0):

# data augmentation
if mode == 0:
return img
elif mode == 1:
Expand All @@ -85,7 +85,7 @@ def data_aug(img, mode=0):


def gen_patches(file_name):

# get multiscale patches from a single image
img = cv2.imread(file_name, 0) # gray scale
h, w = img.shape
patches = []
Expand All @@ -103,18 +103,20 @@ def gen_patches(file_name):


def datagenerator(data_dir='data/Train400', verbose=False):
# generate clean patches from a dataset
file_list = glob.glob(data_dir+'/*.png') # get name list of all .png files
# initrialize
data = []
# generate patches
for i in range(len(file_list)):
patch = gen_patches(file_list[i])
data.append(patch)
patches = gen_patches(file_list[i])
for patch in patches:
data.append(patch)
if verbose:
print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^')
data = np.array(data, dtype='uint8')
data = data.reshape((data.shape[0]*data.shape[1], data.shape[2], data.shape[3], 1))
discard_n = len(data)-len(data)//batch_size*batch_size
data = np.expand_dims(data, axis=3)
discard_n = len(data)-len(data)//batch_size*batch_size # because of batch namalization
data = np.delete(data, range(discard_n), axis=0)
print('^_^-training data finished-^_^')
return data
Expand Down

0 comments on commit 6f286d5

Please sign in to comment.