diff --git a/TrainingCodes/dncnn_pytorch/data_generator.py b/TrainingCodes/dncnn_pytorch/data_generator.py index 6be841b0..f5402e6f 100644 --- a/TrainingCodes/dncnn_pytorch/data_generator.py +++ b/TrainingCodes/dncnn_pytorch/data_generator.py @@ -65,7 +65,7 @@ def show(x, title=None, cbar=False, figsize=None): def data_aug(img, mode=0): - + # data augmentation if mode == 0: return img elif mode == 1: @@ -85,7 +85,7 @@ def data_aug(img, mode=0): def gen_patches(file_name): - + # get multiscale patches from a single image img = cv2.imread(file_name, 0) # gray scale h, w = img.shape patches = [] @@ -103,18 +103,20 @@ def gen_patches(file_name): def datagenerator(data_dir='data/Train400', verbose=False): + # generate clean patches from a dataset file_list = glob.glob(data_dir+'/*.png') # get name list of all .png files # initrialize data = [] # generate patches for i in range(len(file_list)): - patch = gen_patches(file_list[i]) - data.append(patch) + patches = gen_patches(file_list[i]) + for patch in patches: + data.append(patch) if verbose: print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^') data = np.array(data, dtype='uint8') - data = data.reshape((data.shape[0]*data.shape[1], data.shape[2], data.shape[3], 1)) - discard_n = len(data)-len(data)//batch_size*batch_size + data = np.expand_dims(data, axis=3) + discard_n = len(data)-len(data)//batch_size*batch_size # because of batch namalization data = np.delete(data, range(discard_n), axis=0) print('^_^-training data finished-^_^') return data