forked from cszn/DnCNN
Commit
Showing 20 changed files with 477 additions and 0 deletions.
README.md
@@ -0,0 +1,39 @@
# DnCNN-keras

This code is modified from [husqin/DnCNN-keras](https://github.com/husqin/DnCNN-keras).
### Dependencies
```
tensorflow
keras2
numpy
opencv
```
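One way to install these (an illustrative command, not pinned by this commit; note that OpenCV's package on PyPI is `opencv-python`):

```
pip install tensorflow keras numpy opencv-python
```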

### Train

```
python main_train.py
```

### Test

```
python main_test.py
```
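main_test.py also accepts command-line options (defined in its parse_args below); for example, to denoise at noise level 25 and save the outputs:

```
python main_test.py --set_dir data/Test --sigma 25 --save_result 1
```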

### Results

#### Gaussian Denoising

The average PSNR (dB) results of different methods on the BSD68 dataset:
| Noise Level | BM3D  | DnCNN | DnCNN-keras |
|:-----------:|:-----:|:-----:|:-----------:|
|     25      | 28.57 | 29.23 | 29.23       |
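PSNR here is computed in main_test.py below via skimage's compare_psnr on float images in [0, 1]; a minimal equivalent sketch:

```python
import numpy as np

def psnr(x, y, data_range=1.0):
    # peak signal-to-noise ratio between reference x and estimate y
    mse = np.mean((x - y) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)
```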
(Twelve further files in this commit could not be rendered by the diff viewer, and one binary file is not shown.)
data_generator.py
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
#   title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#   author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#   journal={IEEE Transactions on Image Processing},
#   year={2017},
#   volume={26},
#   number={7},
#   pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified from the code at https://github.com/husqin/DnCNN-keras
# =============================================================================

# no need to run this code separately

import glob
# import os
import cv2
import numpy as np
# from multiprocessing import Pool

# patch-extraction hyper-parameters
patch_size, stride = 40, 10
aug_times = 1                # random augmentations per patch
scales = [1, 0.9, 0.8, 0.7]  # relative scales for multi-scale patch extraction
batch_size = 128

def show(x, title=None, cbar=False, figsize=None):
    # display a grayscale image, optionally with a title and a colorbar
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

def data_aug(img, mode=0):
    # the eight symmetries of the square: identity, flips, and 90/180/270-degree rotations
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(img)
    elif mode == 2:
        return np.rot90(img)
    elif mode == 3:
        return np.flipud(np.rot90(img))
    elif mode == 4:
        return np.rot90(img, k=2)
    elif mode == 5:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:
        return np.rot90(img, k=3)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))

def gen_patches(file_name):
    # read image as grayscale
    img = cv2.imread(file_name, 0)
    h, w = img.shape
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        # cv2.resize expects (width, height), not (height, width)
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        # extract overlapping patches
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                # patches.append(x)
                # data augmentation
                for k in range(0, aug_times):
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    patches.append(x_aug)
    return patches
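
# Worked example (assuming 180x180 training images, an assumption not stated in this file):
# each scale s contributes ((int(dim*s) - patch_size)//stride + 1)**2 patches, i.e.
#   s=1.0: 15*15 = 225,  s=0.9: 13*13 = 169,  s=0.8: 11*11 = 121,  s=0.7: 9*9 = 81,
# for a total of 596 patches per image when aug_times = 1.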

def datagenerator(data_dir='data/Train400', verbose=False):
    # get the name list of all .png files
    file_list = glob.glob(data_dir + '/*.png')
    # initialize
    data = []
    # generate patches
    for i in range(len(file_list)):
        patch = gen_patches(file_list[i])
        data.append(patch)
        if verbose:
            print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^')
    data = np.array(data, dtype='uint8')
    # (n_images, n_patches, h, w) -> (n_images*n_patches, h, w, 1)
    data = data.reshape((data.shape[0]*data.shape[1], data.shape[2], data.shape[3], 1))
    # discard leftover patches so the total is a multiple of batch_size
    discard_n = len(data) - len(data)//batch_size*batch_size
    data = np.delete(data, range(discard_n), axis=0)
    print('^_^-training data finished-^_^')
    return data

if __name__ == '__main__':

    data = datagenerator(data_dir='data/Train400')

    # print('Shape of result = ' + str(res.shape))
    # print('Saving data...')
    # if not os.path.exists(save_dir):
    #     os.mkdir(save_dir)
    # np.save(save_dir+'clean_patches.npy', res)
    # print('Done.')
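A minimal sketch of how this generator's output could feed training (my assumption; main_train.py is not shown in this commit, and the /255.0 scaling and sigma = 25 mirror main_test.py):

```python
import numpy as np
from data_generator import datagenerator  # assumes this file is saved as data_generator.py

data = datagenerator(data_dir='data/Train400')  # uint8 array of shape (N, 40, 40, 1)
x = data.astype('float32') / 255.0              # clean patches scaled to [0, 1]
y = x + np.random.normal(0, 25/255.0, x.shape)  # noisy inputs at noise level sigma = 25
```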
main_test.py
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
#   title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#   author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#   journal={IEEE Transactions on Image Processing},
#   year={2017},
#   volume={26},
#   number={7},
#   pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified from the code at https://github.com/husqin/DnCNN-keras
# =============================================================================

# run this to test the model

import argparse
import os, time, datetime
# import PIL.Image as Image
import numpy as np
from keras.models import load_model, model_from_json
# note: compare_psnr/compare_ssim require scikit-image < 0.18; later releases rename
# them to peak_signal_noise_ratio and structural_similarity
from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    # note: type=list only works with the default value; a --set_names string passed
    # on the command line would be split into individual characters
    parser.add_argument('--set_names', default=['Set68','Set12'], type=list, help='names of test datasets')
    parser.add_argument('--sigma', default=25, type=int, help='noise level')
    parser.add_argument('--model_dir', default=os.path.join('models','DnCNN_sigma25'), type=str, help='directory of the model')
    parser.add_argument('--model_name', default='model_001.hdf5', type=str, help='name of the model file')
    parser.add_argument('--result_dir', default='results', type=str, help='directory of results')
    parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()

def to_tensor(img):
    # HxW -> 1xHxWx1; HxWxC -> CxHxWx1 (a batch of single-channel images)
    if img.ndim == 2:
        return img[np.newaxis, ..., np.newaxis]
    elif img.ndim == 3:
        return np.moveaxis(img, 2, 0)[..., np.newaxis]

def from_tensor(img):
    # inverse of to_tensor: drop the channel axis and restore HxW(xC)
    return np.squeeze(np.moveaxis(img[..., 0], 0, -1))

def log(*args, **kwargs):
    # print with a timestamp prefix
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S:'), *args, **kwargs)

def save_result(result, path):
    # save an array as text (.txt/.dlm) or as an image (any other extension)
    path = path if path.find('.') != -1 else path + '.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))

def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

if __name__ == '__main__':

    args = parse_args()

    # =============================================================================
    # # serialize model to JSON
    # model_json = model.to_json()
    # with open("model.json", "w") as json_file:
    #     json_file.write(model_json)
    # # serialize weights to HDF5
    # model.save_weights("model.h5")
    # print("Saved model")
    # =============================================================================

    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):
        # load the architecture from JSON and create the model
        json_file = open(os.path.join(args.model_dir, 'model.json'), 'r')
        loaded_model_json = json_file.read()
        json_file.close()
        model = model_from_json(loaded_model_json)
        # load the weights into the new model
        model.load_weights(os.path.join(args.model_dir, 'model.h5'))
        log('load trained model on Train400 dataset by kai')
    else:
        model = load_model(os.path.join(args.model_dir, args.model_name), compile=False)
        log('load trained model')

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:

        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))
        psnrs = []
        ssims = []

        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            if im.endswith('.jpg') or im.endswith('.bmp') or im.endswith('.png'):
                # x = np.array(Image.open(os.path.join(args.set_dir,set_cur,im)), dtype='float32') / 255.0
                x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32) / 255.0
                np.random.seed(seed=0)  # for reproducibility
                y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # add Gaussian noise without clipping
                y = y.astype(np.float32)
                y_ = to_tensor(y)
                start_time = time.time()
                x_ = model.predict(y_)  # inference
                elapsed_time = time.time() - start_time
                print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))
                x_ = from_tensor(x_)
                psnr_x_ = compare_psnr(x, x_)
                ssim_x_ = compare_ssim(x, x_)
                if args.save_result:
                    name, ext = os.path.splitext(im)
                    show(np.hstack((y, x_)))  # show the noisy and denoised images side by side
                    save_result(x_, path=os.path.join(args.result_dir, set_cur, name + '_dncnn' + ext))  # save the denoised image
                psnrs.append(psnr_x_)
                ssims.append(ssim_x_)

        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        # append the averages as the last entries of the per-image lists
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)

        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))

        log('Dataset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))