forked from cszn/DnCNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
532 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# DnCNN-keras | ||
|
||
This code is modified from (SaoYan)[https://github.com/SaoYan/DnCNN-PyTorch]. | ||
|
||
## Dependence | ||
``` | ||
pytorch 0.4.1 | ||
``` | ||
|
||
## Train | ||
``` | ||
main_train.py | ||
``` | ||
|
||
## Test | ||
|
||
``` | ||
main_test.py | ||
``` | ||
|
||
## Results | ||
|
||
### Gaussian Denoising | ||
|
||
The average PSNR(dB) results of different methods on the BSD68 dataset. | ||
|
||
| Noise Level | BM3D | DnCNN | DnCNN-PyTorch | | ||
|:-------:|:-------:|:-------:|:-------:| | ||
| 25 | 28.57 | 29.23 | 29.24 | | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# ============================================================================= | ||
# @article{zhang2017beyond, | ||
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising}, | ||
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei}, | ||
# journal={IEEE Transactions on Image Processing}, | ||
# year={2017}, | ||
# volume={26}, | ||
# number={7}, | ||
# pages={3142-3155}, | ||
# } | ||
# by Kai Zhang (08/2018) | ||
# [email protected] | ||
# https://github.com/cszn | ||
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch | ||
# ============================================================================= | ||
|
||
# no need to run this code separately | ||
|
||
|
||
import glob | ||
import cv2 | ||
import numpy as np | ||
# from multiprocessing import Pool | ||
from torch.utils.data import Dataset | ||
import torch | ||
|
||
patch_size, stride = 40, 10 | ||
aug_times = 1 | ||
scales = [1, 0.9, 0.8, 0.7] | ||
batch_size = 128 | ||
|
||
|
||
class DenoisingDataset(Dataset): | ||
"""Dataset wrapping tensors. | ||
Arguments: | ||
xs (Tensor): clean image patches | ||
sigma: noise level, e.g., 25 | ||
""" | ||
def __init__(self, xs, sigma): | ||
super(DenoisingDataset, self).__init__() | ||
self.xs = xs | ||
self.sigma = sigma | ||
|
||
def __getitem__(self, index): | ||
batch_x = self.xs[index] | ||
noise = torch.randn(batch_x.size()).mul_(self.sigma/255.0) | ||
batch_y = batch_x + noise | ||
return batch_y, batch_x | ||
|
||
def __len__(self): | ||
return self.xs.size(0) | ||
|
||
|
||
def show(x, title=None, cbar=False, figsize=None): | ||
import matplotlib.pyplot as plt | ||
plt.figure(figsize=figsize) | ||
plt.imshow(x, interpolation='nearest', cmap='gray') | ||
if title: | ||
plt.title(title) | ||
if cbar: | ||
plt.colorbar() | ||
plt.show() | ||
|
||
|
||
def data_aug(img, mode=0): | ||
|
||
if mode == 0: | ||
return img | ||
elif mode == 1: | ||
return np.flipud(img) | ||
elif mode == 2: | ||
return np.rot90(img) | ||
elif mode == 3: | ||
return np.flipud(np.rot90(img)) | ||
elif mode == 4: | ||
return np.rot90(img, k=2) | ||
elif mode == 5: | ||
return np.flipud(np.rot90(img, k=2)) | ||
elif mode == 6: | ||
return np.rot90(img, k=3) | ||
elif mode == 7: | ||
return np.flipud(np.rot90(img, k=3)) | ||
|
||
|
||
def gen_patches(file_name): | ||
|
||
img = cv2.imread(file_name, 0) # gray scale | ||
h, w = img.shape | ||
patches = [] | ||
for s in scales: | ||
h_scaled, w_scaled = int(h*s), int(w*s) | ||
img_scaled = cv2.resize(img, (h_scaled, w_scaled), interpolation=cv2.INTER_CUBIC) | ||
# extract patches | ||
for i in range(0, h_scaled-patch_size+1, stride): | ||
for j in range(0, w_scaled-patch_size+1, stride): | ||
x = img_scaled[i:i+patch_size, j:j+patch_size] | ||
for k in range(0, aug_times): | ||
x_aug = data_aug(x, mode=np.random.randint(0, 8)) | ||
patches.append(x_aug) | ||
return patches | ||
|
||
|
||
def datagenerator(data_dir='data/Train400', verbose=False): | ||
file_list = glob.glob(data_dir+'/*.png') # get name list of all .png files | ||
# initrialize | ||
data = [] | ||
# generate patches | ||
for i in range(len(file_list)): | ||
patch = gen_patches(file_list[i]) | ||
data.append(patch) | ||
if verbose: | ||
print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^') | ||
data = np.array(data, dtype='uint8') | ||
data = data.reshape((data.shape[0]*data.shape[1], data.shape[2], data.shape[3], 1)) | ||
discard_n = len(data)-len(data)//batch_size*batch_size | ||
data = np.delete(data, range(discard_n), axis=0) | ||
print('^_^-training data finished-^_^') | ||
return data | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
data = datagenerator(data_dir='data/Train400') | ||
|
||
|
||
# print('Shape of result = ' + str(res.shape)) | ||
# print('Saving data...') | ||
# if not os.path.exists(save_dir): | ||
# os.mkdir(save_dir) | ||
# np.save(save_dir+'clean_patches.npy', res) | ||
# print('Done.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# ============================================================================= | ||
# @article{zhang2017beyond, | ||
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising}, | ||
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei}, | ||
# journal={IEEE Transactions on Image Processing}, | ||
# year={2017}, | ||
# volume={26}, | ||
# number={7}, | ||
# pages={3142-3155}, | ||
# } | ||
# by Kai Zhang (08/2018) | ||
# [email protected] | ||
# https://github.com/cszn | ||
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch | ||
# ============================================================================= | ||
|
||
# run this to test the model | ||
|
||
import argparse | ||
import os, time, datetime | ||
# import PIL.Image as Image | ||
import numpy as np | ||
import torch.nn as nn | ||
import torch.nn.init as init | ||
import torch | ||
from skimage.measure import compare_psnr, compare_ssim | ||
from skimage.io import imread, imsave | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset') | ||
parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset') | ||
parser.add_argument('--sigma', default=25, type=int, help='noise level') | ||
parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma25'), help='directory of the model') | ||
parser.add_argument('--model_name', default='model_001.pth', type=str, help='the model name') | ||
parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset') | ||
parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0') | ||
return parser.parse_args() | ||
|
||
|
||
def log(*args, **kwargs): | ||
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) | ||
|
||
|
||
def save_result(result, path): | ||
path = path if path.find('.') != -1 else path+'.png' | ||
ext = os.path.splitext(path)[-1] | ||
if ext in ('.txt', '.dlm'): | ||
np.savetxt(path, result, fmt='%2.4f') | ||
else: | ||
imsave(path, np.clip(result, 0, 1)) | ||
|
||
|
||
def show(x, title=None, cbar=False, figsize=None): | ||
import matplotlib.pyplot as plt | ||
plt.figure(figsize=figsize) | ||
plt.imshow(x, interpolation='nearest', cmap='gray') | ||
if title: | ||
plt.title(title) | ||
if cbar: | ||
plt.colorbar() | ||
plt.show() | ||
|
||
|
||
class DnCNN(nn.Module): | ||
|
||
def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3): | ||
super(DnCNN, self).__init__() | ||
kernel_size = 3 | ||
padding = 1 | ||
layers = [] | ||
layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True)) | ||
layers.append(nn.ReLU(inplace=True)) | ||
for _ in range(depth-2): | ||
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False)) | ||
layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95)) | ||
layers.append(nn.ReLU(inplace=True)) | ||
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False)) | ||
self.dncnn = nn.Sequential(*layers) | ||
self._initialize_weights() | ||
|
||
def forward(self, x): | ||
y = x | ||
out = self.dncnn(x) | ||
return y-out | ||
|
||
def _initialize_weights(self): | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
init.orthogonal_(m.weight) | ||
print('init weight') | ||
if m.bias is not None: | ||
init.constant_(m.bias, 0) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
init.constant_(m.weight, 1) | ||
init.constant_(m.bias, 0) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
args = parse_args() | ||
|
||
# model = DnCNN() | ||
if not os.path.exists(os.path.join(args.model_dir, args.model_name)): | ||
|
||
model = torch.load(os.path.join(args.model_dir, 'model.pth')) | ||
# load weights into new model | ||
log('load trained model on Train400 dataset by kai') | ||
else: | ||
# model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name))) | ||
model = torch.load(os.path.join(args.model_dir, args.model_name)) | ||
log('load trained model') | ||
|
||
# params = model.state_dict() | ||
# print(params.values()) | ||
# print(params.keys()) | ||
# | ||
# for key, value in params.items(): | ||
# print(key) # parameter name | ||
# print(params['dncnn.12.running_mean']) | ||
# print(model.state_dict()) | ||
|
||
model.eval() # evaluation mode | ||
# model.train() | ||
|
||
if torch.cuda.is_available(): | ||
model = model.cuda() | ||
|
||
if not os.path.exists(args.result_dir): | ||
os.mkdir(args.result_dir) | ||
|
||
for set_cur in args.set_names: | ||
|
||
if not os.path.exists(os.path.join(args.result_dir, set_cur)): | ||
os.mkdir(os.path.join(args.result_dir, set_cur)) | ||
psnrs = [] | ||
ssims = [] | ||
|
||
for im in os.listdir(os.path.join(args.set_dir, set_cur)): | ||
if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"): | ||
|
||
x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0 | ||
np.random.seed(seed=0) # for reproducibility | ||
y = x + np.random.normal(0, args.sigma/255.0, x.shape) # Add Gaussian noise without clipping | ||
y = y.astype(np.float32) | ||
y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1]) | ||
|
||
torch.cuda.synchronize() | ||
start_time = time.time() | ||
y_ = y_.cuda() | ||
x_ = model(y_) # inference | ||
x_ = x_.view(y.shape[0], y.shape[1]) | ||
x_ = x_.cpu() | ||
x_ = x_.detach().numpy().astype(np.float32) | ||
torch.cuda.synchronize() | ||
elapsed_time = time.time() - start_time | ||
print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time)) | ||
|
||
psnr_x_ = compare_psnr(x, x_) | ||
ssim_x_ = compare_ssim(x, x_) | ||
if args.save_result: | ||
name, ext = os.path.splitext(im) | ||
show(np.hstack((y, x_))) # show the image | ||
save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext)) # save the denoised image | ||
psnrs.append(psnr_x_) | ||
ssims.append(ssim_x_) | ||
psnr_avg = np.mean(psnrs) | ||
ssim_avg = np.mean(ssims) | ||
psnrs.append(psnr_avg) | ||
ssims.append(ssim_avg) | ||
if args.save_result: | ||
save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt')) | ||
log('Datset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg)) | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.