Commit

Add files via upload
cszn authored Aug 30, 2018
1 parent 0e51288 commit e9eb8c2
Showing 5 changed files with 532 additions and 0 deletions.
36 changes: 36 additions & 0 deletions TrainingCodes/dncnn_pytorch/README.md
@@ -0,0 +1,36 @@
# DnCNN-PyTorch

This code is modified from [SaoYan/DnCNN-PyTorch](https://github.com/SaoYan/DnCNN-PyTorch).

## Dependencies
```
pytorch 0.4.1
```

`main_test.py` also uses `scikit-image` (`compare_psnr` / `compare_ssim` from `skimage.measure`), and `data_generator.py` uses `opencv-python`.

## Train
```
python main_train.py
```

## Test

```
python main_test.py
```

## Results

### Gaussian Denoising

The average PSNR (dB) results of different methods on the BSD68 dataset.

| Noise Level | BM3D | DnCNN | DnCNN-PyTorch |
|:-------:|:-------:|:-------:|:-------:|
| 25 | 28.57 | 29.23 | 29.24 |
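
For reference, the PSNR values above follow the standard definition for images scaled to [0, 1]; a minimal sketch (illustrative, not part of this repository):

```
import numpy as np

def psnr(clean, denoised):
    # peak signal-to-noise ratio for float images in [0, 1]
    mse = np.mean((clean - denoised) ** 2)
    return 10 * np.log10(1.0 / mse)
```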







133 changes: 133 additions & 0 deletions TrainingCodes/dncnn_pytorch/data_generator.py
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# no need to run this code separately


import glob
import cv2
import numpy as np
# from multiprocessing import Pool
from torch.utils.data import Dataset
import torch

patch_size, stride = 40, 10
aug_times = 1
scales = [1, 0.9, 0.8, 0.7]
batch_size = 128


class DenoisingDataset(Dataset):
"""Dataset wrapping tensors.
Arguments:
xs (Tensor): clean image patches
sigma: noise level, e.g., 25
"""
def __init__(self, xs, sigma):
super(DenoisingDataset, self).__init__()
self.xs = xs
self.sigma = sigma

    def __getitem__(self, index):
        batch_x = self.xs[index]  # clean patch
        noise = torch.randn(batch_x.size()).mul_(self.sigma/255.0)  # AWGN with std sigma/255
        batch_y = batch_x + noise  # noisy observation
        return batch_y, batch_x

def __len__(self):
return self.xs.size(0)
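
# Usage sketch (illustrative, not part of the original training script):
# wrap the generated patches in DenoisingDataset and iterate with a DataLoader.
#
# from torch.utils.data import DataLoader
# xs = datagenerator().astype('float32') / 255.0
# xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # NHWC -> NCHW
# loader = DataLoader(DenoisingDataset(xs, sigma=25), batch_size=batch_size, shuffle=True)
# for batch_y, batch_x in loader:
#     pass  # batch_y: noisy input, batch_x: clean target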


def show(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(x, interpolation='nearest', cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()


def data_aug(img, mode=0):
    # the 8 modes enumerate the dihedral group of the square:
    # identity, three rotations, and their up-down flipped counterparts
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(img)
    elif mode == 2:
        return np.rot90(img)
    elif mode == 3:
        return np.flipud(np.rot90(img))
    elif mode == 4:
        return np.rot90(img, k=2)
    elif mode == 5:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:
        return np.rot90(img, k=3)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))


def gen_patches(file_name):

img = cv2.imread(file_name, 0) # gray scale
h, w = img.shape
patches = []
for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)  # cv2.resize expects (width, height)
# extract patches
for i in range(0, h_scaled-patch_size+1, stride):
for j in range(0, w_scaled-patch_size+1, stride):
x = img_scaled[i:i+patch_size, j:j+patch_size]
for k in range(0, aug_times):
x_aug = data_aug(x, mode=np.random.randint(0, 8))
patches.append(x_aug)
return patches


def datagenerator(data_dir='data/Train400', verbose=False):
file_list = glob.glob(data_dir+'/*.png') # get name list of all .png files
    # initialize
data = []
# generate patches
for i in range(len(file_list)):
patch = gen_patches(file_list[i])
data.append(patch)
if verbose:
print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^')
    data = np.array(data, dtype='uint8')
    data = data.reshape((data.shape[0]*data.shape[1], data.shape[2], data.shape[3], 1))  # (n_patches, patch_size, patch_size, 1)
    discard_n = len(data)-len(data)//batch_size*batch_size
    data = np.delete(data, range(discard_n), axis=0)  # drop the tail so n_patches is a multiple of batch_size
print('^_^-training data finished-^_^')
return data


if __name__ == '__main__':

data = datagenerator(data_dir='data/Train400')


# print('Shape of result = ' + str(res.shape))
# print('Saving data...')
# if not os.path.exists(save_dir):
# os.mkdir(save_dir)
# np.save(save_dir+'clean_patches.npy', res)
# print('Done.')
184 changes: 184 additions & 0 deletions TrainingCodes/dncnn_pytorch/main_test.py
@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-

# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to test the model

import argparse
import os, time, datetime
# import PIL.Image as Image
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch
from skimage.measure import compare_psnr, compare_ssim  # available in older scikit-image releases (later deprecated/renamed)
from skimage.io import imread, imsave


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', default=['Set68', 'Set12'], help='names of the test datasets')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma25'), help='directory of the model')
parser.add_argument('--model_name', default='model_001.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory to save the results')
parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
return parser.parse_args()


def log(*args, **kwargs):
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


def save_result(result, path):
path = path if path.find('.') != -1 else path+'.png'
ext = os.path.splitext(path)[-1]
if ext in ('.txt', '.dlm'):
np.savetxt(path, result, fmt='%2.4f')
else:
imsave(path, np.clip(result, 0, 1))


def show(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(x, interpolation='nearest', cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()


class DnCNN(nn.Module):

def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
super(DnCNN, self).__init__()
        padding = kernel_size // 2  # 'same' padding; the original hard-coded kernel_size = 3, padding = 1
layers = []
layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
layers.append(nn.ReLU(inplace=True))
for _ in range(depth-2):
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)  # the network predicts the noise (residual)
        return y-out         # denoised image = noisy input - predicted noise

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.orthogonal_(m.weight)
print('init weight')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
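
# Sanity-check sketch (illustrative): the default configuration reproduces the
# 17-layer grayscale DnCNN-S model, with roughly 0.56M parameters.
#
# net = DnCNN(depth=17, n_channels=64, image_channels=1)
# print(sum(p.numel() for p in net.parameters()))  # ~556k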


if __name__ == '__main__':

args = parse_args()

# model = DnCNN()
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):
        # fall back to the default pretrained model
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('loaded the default model trained on the Train400 dataset by Kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('loaded trained model')

# params = model.state_dict()
# print(params.values())
# print(params.keys())
#
# for key, value in params.items():
# print(key) # parameter name
# print(params['dncnn.12.running_mean'])
# print(model.state_dict())

model.eval() # evaluation mode
# model.train()

if torch.cuda.is_available():
model = model.cuda()

if not os.path.exists(args.result_dir):
os.mkdir(args.result_dir)

for set_cur in args.set_names:

if not os.path.exists(os.path.join(args.result_dir, set_cur)):
os.mkdir(os.path.join(args.result_dir, set_cur))
psnrs = []
ssims = []

for im in os.listdir(os.path.join(args.set_dir, set_cur)):
if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):

x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0
np.random.seed(seed=0) # for reproducibility
y = x + np.random.normal(0, args.sigma/255.0, x.shape) # Add Gaussian noise without clipping
y = y.astype(np.float32)
y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

                start_time = time.time()
                if torch.cuda.is_available():
                    y_ = y_.cuda()
                x_ = model(y_)  # inference
                x_ = x_.view(y.shape[0], y.shape[1])
                x_ = x_.cpu()
                x_ = x_.detach().numpy().astype(np.float32)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()  # wait for GPU kernels to finish before timing
                elapsed_time = time.time() - start_time
print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

psnr_x_ = compare_psnr(x, x_)
ssim_x_ = compare_ssim(x, x_)
if args.save_result:
name, ext = os.path.splitext(im)
show(np.hstack((y, x_))) # show the image
save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext)) # save the denoised image
psnrs.append(psnr_x_)
ssims.append(ssim_x_)
psnr_avg = np.mean(psnrs)
ssim_avg = np.mean(ssims)
psnrs.append(psnr_avg)
ssims.append(ssim_avg)
if args.save_result:
save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Dataset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))








