Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
cszn authored Aug 23, 2018
1 parent c6d6af9 commit bc43813
Show file tree
Hide file tree
Showing 20 changed files with 477 additions and 0 deletions.
39 changes: 39 additions & 0 deletions TrainingCodes/dncnn_keras/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# DnCNN-keras

This code is modified from (husqin)[https://github.com/husqin/DnCNN-keras].

### Dependence
```
tensorflow
keras2
numpy
opencv
```

### Train
```
main_train.py
```

### Test

```
main_test.py
```

### Results

#### Gaussian Denoising

The average PSNR(dB) results of different methods on the BSD68 dataset.

| Noise Level | BM3D | DnCNN | DnCNN-keras |
|:-------:|:-------:|:-------:|:-------:|
| 25 | 28.57 | 29.23 | 29.23 |







Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/04.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/07.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/08.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/09.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added TrainingCodes/dncnn_keras/data/Test/Set12/12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
113 changes: 113 additions & 0 deletions TrainingCodes/dncnn_keras/data_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/husqin/DnCNN-keras
# =============================================================================

# no need to run this code separately


import glob
#import os
import cv2
import numpy as np
#from multiprocessing import Pool


patch_size, stride = 40, 10
aug_times = 1
scales = [1, 0.9, 0.8, 0.7]
batch_size = 128


def show(x,title=None,cbar=False,figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(x,interpolation='nearest',cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()

def data_aug(img, mode=0):

if mode == 0:
return img
elif mode == 1:
return np.flipud(img)
elif mode == 2:
return np.rot90(img)
elif mode == 3:
return np.flipud(np.rot90(img))
elif mode == 4:
return np.rot90(img, k=2)
elif mode == 5:
return np.flipud(np.rot90(img, k=2))
elif mode == 6:
return np.rot90(img, k=3)
elif mode == 7:
return np.flipud(np.rot90(img, k=3))

def gen_patches(file_name):

# read image
img = cv2.imread(file_name, 0) # gray scale
h, w = img.shape
patches = []
for s in scales:
h_scaled, w_scaled = int(h*s),int(w*s)
img_scaled = cv2.resize(img, (h_scaled,w_scaled), interpolation=cv2.INTER_CUBIC)
# extract patches
for i in range(0, h_scaled-patch_size+1, stride):
for j in range(0, w_scaled-patch_size+1, stride):
x = img_scaled[i:i+patch_size, j:j+patch_size]
#patches.append(x)
# data aug
for k in range(0, aug_times):
x_aug = data_aug(x, mode=np.random.randint(0,8))
patches.append(x_aug)

return patches

def datagenerator(data_dir='data/Train400',verbose=False):

file_list = glob.glob(data_dir+'/*.png') # get name list of all .png files
# initrialize
data = []
# generate patches
for i in range(len(file_list)):
patch = gen_patches(file_list[i])
data.append(patch)
if verbose:
print(str(i+1)+'/'+ str(len(file_list)) + ' is done ^_^')
data = np.array(data, dtype='uint8')
data = data.reshape((data.shape[0]*data.shape[1],data.shape[2],data.shape[3],1))
discard_n = len(data)-len(data)//batch_size*batch_size;
data = np.delete(data,range(discard_n),axis = 0)
print('^_^-training data finished-^_^')
return data

if __name__ == '__main__':

data = datagenerator(data_dir='data/Train400')


# print('Shape of result = ' + str(res.shape))
# print('Saving data...')
# if not os.path.exists(save_dir):
# os.mkdir(save_dir)
# np.save(save_dir+'clean_patches.npy', res)
# print('Done.')
145 changes: 145 additions & 0 deletions TrainingCodes/dncnn_keras/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/husqin/DnCNN-keras
# =============================================================================

# run this to test the model

import argparse
import os, time, datetime
#import PIL.Image as Image
import numpy as np
from keras.models import load_model, model_from_json
from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
parser.add_argument('--set_names', default=['Set68','Set12'], type=list, help='name of test dataset')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--model_dir', default=os.path.join('models','DnCNN_sigma25'), type=str, help='directory of the model')
parser.add_argument('--model_name', default='model_001.hdf5', type=str, help='the model name')
parser.add_argument('--result_dir', default='results', type=str, help='directory of results')
parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
return parser.parse_args()

def to_tensor(img):
if img.ndim == 2:
return img[np.newaxis,...,np.newaxis]
elif img.ndim == 3:
return np.moveaxis(img,2,0)[...,np.newaxis]

def from_tensor(img):
return np.squeeze(np.moveaxis(img[...,0],0,-1))

def log(*args,**kwargs):
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),*args,**kwargs)

def save_result(result,path):
path = path if path.find('.') != -1 else path+'.png'
ext = os.path.splitext(path)[-1]
if ext in ('.txt','.dlm'):
np.savetxt(path,result,fmt='%2.4f')
else:
imsave(path,np.clip(result,0,1))


def show(x,title=None,cbar=False,figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(x,interpolation='nearest',cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()


if __name__ == '__main__':

args = parse_args()


# =============================================================================
# # serialize model to JSON
# model_json = model.to_json()
# with open("model.json", "w") as json_file:
# json_file.write(model_json)
# # serialize weights to HDF5
# model.save_weights("model.h5")
# print("Saved model")
# =============================================================================

if not os.path.exists(os.path.join(args.model_dir, args.model_name)):
# load json and create model
json_file = open(os.path.join(args.model_dir,'model.json'), 'r')
loaded_model_json = json_file.read()
json_file.close()
model = model_from_json(loaded_model_json)
# load weights into new model
model.load_weights(os.path.join(args.model_dir,'model.h5'))
log('load trained model on Train400 dataset by kai')
else:
model = load_model(os.path.join(args.model_dir, args.model_name),compile=False)
log('load trained model')

if not os.path.exists(args.result_dir):
os.mkdir(args.result_dir)

for set_cur in args.set_names:

if not os.path.exists(os.path.join(args.result_dir,set_cur)):
os.mkdir(os.path.join(args.result_dir,set_cur))
psnrs = []
ssims = []

for im in os.listdir(os.path.join(args.set_dir,set_cur)):
if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
#x = np.array(Image.open(os.path.join(args.set_dir,set_cur,im)), dtype='float32') / 255.0
x = np.array(imread(os.path.join(args.set_dir,set_cur,im)), dtype=np.float32) / 255.0
np.random.seed(seed=0) # for reproducibility
y = x + np.random.normal(0, args.sigma/255.0, x.shape) # Add Gaussian noise without clipping
y = y.astype(np.float32)
y_ = to_tensor(y)
start_time = time.time()
x_ = model.predict(y_) # inference
elapsed_time = time.time() - start_time
print('%10s : %10s : %2.4f second'%(set_cur,im,elapsed_time))
x_=from_tensor(x_)
psnr_x_ = compare_psnr(x, x_)
ssim_x_ = compare_ssim(x, x_)
if args.save_result:
name, ext = os.path.splitext(im)
show(np.hstack((y,x_))) # show the image
save_result(x_,path=os.path.join(args.result_dir,set_cur,name+'_dncnn'+ext)) # save the denoised image
psnrs.append(psnr_x_)
ssims.append(ssim_x_)

psnr_avg = np.mean(psnrs)
ssim_avg = np.mean(ssims)
psnrs.append(psnr_avg)
ssims.append(ssim_avg)

if args.save_result:
save_result(np.hstack((psnrs,ssims)),path=os.path.join(args.result_dir,set_cur,'results.txt'))

log('Datset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))




Loading

0 comments on commit bc43813

Please sign in to comment.