diff --git a/README.md b/README.md
index f14c936..36d683f 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,32 @@
+# What is different between this modified version and the original Zero-DCE?
+
+(1) Training in any of 7 color spaces ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV").
+
+```
+cd Zero-DCE_code
+```
+```
+python lowlight_train.py --channel ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", or "LUV")
+```
+(2) Pretrained weights (200 epochs) for each of the 7 color spaces.
+
+```
+./Zero-DCE_code/snapshots/("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", or "LUV").pth
+```
+(3) Support for enhancing videos as well as images.
+
+```
+cd Zero-DCE_code
+```
+```
+python lowlight_test.py --mode (video/image) --channel ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", or "LUV")
+```
+(4) TensorBoard logging of the training losses.
+
+```
+tensorboard --logdir log/train_loss_("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", or "LUV")
+```
+
 # Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement
 
 You can find more details here: https://li-chongyi.github.io/Proj_Zero-DCE.html. Have fun!
@@ -76,8 +105,5 @@ The code is made available for academic research purpose only. Under Attribution
 (Full paper: http://openaccess.thecvf.com/content_CVPR_2020/papers/Guo_Zero-Reference_Deep_Curve_Estimation_for_Low-Light_Image_Enhancement_CVPR_2020_paper.pdf)
 
 ## Contact
-If you have any questions, please contact Chongyi Li at lichongyi25@gmail.com or Chunle Guo at guochunle@tju.edu.cn.
-## TensorFlow Version
-Thanks tuvovan (vovantu.hust@gmail.com) who re-produces our code by TF. The results of TF version look similar with our Pytorch version. But I do not have enough time to check the details.
-https://github.com/tuvovan/Zero_DCE_TF
+If you have any questions, please contact ICHEN LU at luyijan@gmail.com.
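For reference, the `--channel` option selects both the color conversion and the per-channel normalization applied before the network. A minimal sketch of the `(scale, offset)` pairs this patch uses, distilled from `Color_Choice()` in `lowlight_test.py` (values are OpenCV 8-bit ranges; `normalized = (value - offset) / scale`):

```python
# Sketch: per-channel (scale, offset) mapping OpenCV 8-bit values into [0, 1].
# Distilled from the `n` lists in Color_Choice() below; not a module in the repo.
NORMALIZATION = {
    "RGB":   [(255, 0), (255, 0), (255, 0)],
    "HSV":   [(180, 0), (255, 0), (255, 0)],  # OpenCV 8-bit hue spans 0..179
    "HLS":   [(180, 0), (255, 0), (255, 0)],
    "YCbCr": [(255, 0), (255, 0), (255, 0)],
    "YUV":   [(255, 0), (255, 0), (255, 0)],
    "LAB":   [(255, 0), (254, 1), (254, 1)],  # a/b shifted by 1 in this fork
    "LUV":   [(255, 0), (255, 0), (255, 0)],
}
```

Only OpenCV's hue channel spans 0–179, which is why HSV and HLS scale their first channel by 180 rather than 255.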
diff --git a/Zero-DCE_code/dataloader.py b/Zero-DCE_code/dataloader.py
index e01217c..4254d48 100644
--- a/Zero-DCE_code/dataloader.py
+++ b/Zero-DCE_code/dataloader.py
@@ -15,10 +15,7 @@ def populate_train_list(lowlight_images_path):
-
-
-
-    image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")
+    image_list_lowlight = glob.glob(lowlight_images_path + "*.JPG")
 
     train_list = image_list_lowlight
 
@@ -26,34 +23,67 @@ def populate_train_list(lowlight_images_path):
     return train_list
 
-
-
 class lowlight_loader(data.Dataset):
-
-    def __init__(self, lowlight_images_path):
-
-        self.train_list = populate_train_list(lowlight_images_path)
-        self.size = 256
-
-        self.data_list = self.train_list
-        print("Total training examples:", len(self.train_list))
-
-
-
-
-    def __getitem__(self, index):
-
-        data_lowlight_path = self.data_list[index]
-
-        data_lowlight = Image.open(data_lowlight_path)
-
-        data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
-
-        data_lowlight = (np.asarray(data_lowlight)/255.0)
-        data_lowlight = torch.from_numpy(data_lowlight).float()
-
-        return data_lowlight.permute(2,0,1)
-
-    def __len__(self):
-        return len(self.data_list)
-
+    def __init__(self, lowlight_images_path, channel):
+        self.train_list = populate_train_list(lowlight_images_path)
+        self.size = 256
+        self.channel = channel
+
+        self.data_list = self.train_list
+        print("Total training examples:", len(self.train_list))
+
+    def __getitem__(self, index):
+
+        data_lowlight_path = self.data_list[index]
+        data_lowlight = cv2.imread(data_lowlight_path)
+        data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB)
+        data_lowlight = cv2.resize(data_lowlight,(self.size,self.size),interpolation = cv2.INTER_AREA)
+        if self.channel=="RGB":
+            data_lowlight = (np.asarray(data_lowlight)/255.0)
+        else:
+            if self.channel=="HSV":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HSV)
+                H, S, V = cv2.split(data_lowlight)
+                data_lowlight = np.asarray(data_lowlight).copy()
+                data_lowlight_1 = ((H)/(180.0))
+                data_lowlight_2 = ((S)/(255.0))
+                data_lowlight_3 = ((V)/(255.0))
+            elif self.channel=="HLS":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HLS)
+                H, L, S = cv2.split(data_lowlight)
+                data_lowlight = np.asarray(data_lowlight).copy()
+                data_lowlight_1 = ((H)/(180.0))
+                data_lowlight_2 = ((L)/(255.0))
+                data_lowlight_3 = ((S)/(255.0))
+            elif self.channel=="YCbCr":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YCrCb)
+                Y, Cr, Cb = cv2.split(data_lowlight)
+                data_lowlight_1 = ((Y)/(255.0))
+                data_lowlight_2 = ((Cr)/(255.0))
+                data_lowlight_3 = ((Cb)/(255.0))
+            elif self.channel=="YUV":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YUV)
+                Y, U, V = cv2.split(data_lowlight)
+                data_lowlight_1 = ((Y)/(255.0))
+                data_lowlight_2 = ((U)/(255.0))
+                data_lowlight_3 = ((V)/(255.0))
+            elif self.channel=="LAB":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Lab)
+                L, A, B = cv2.split(data_lowlight)
+                data_lowlight_1 = ((L-0.0)/(255.0-0.0))
+                data_lowlight_2 = ((A-1.0)/(255.0-1.0))
+                data_lowlight_3 = ((B-1.0)/(255.0-1.0))
+            elif self.channel=="LUV":
+                data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Luv)
+                L, U, V = cv2.split(data_lowlight)
+                data_lowlight_1 = ((L)/(255.0))
+                data_lowlight_2 = ((U)/(255.0))
+                data_lowlight_3 = ((V)/(255.0))
+
+            data_lowlight = cv2.merge([data_lowlight_1,data_lowlight_2,data_lowlight_3])
+
+        data_lowlight = torch.from_numpy(data_lowlight).float()
+        data_lowlight = data_lowlight.permute(2,0,1)
+        return data_lowlight
+
+    def __len__(self):
+        return len(self.data_list)
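A minimal usage sketch of the new loader signature (paths are hypothetical; it also assumes `import cv2` is present at the top of `dataloader.py`, which this hunk does not show):

```python
# Run from Zero-DCE_code/ so `dataloader` resolves to the file above.
import torch
import dataloader

# Build an HSV-normalized dataset; each item is resized to 256x256 and
# returned as a float tensor of shape (3, 256, 256) in channel-first order.
train_dataset = dataloader.lowlight_loader("data/train_data/", "HSV")
sample = train_dataset[0]
print(sample.shape)  # torch.Size([3, 256, 256])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
```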
diff --git a/Zero-DCE_code/lowlight_test.py b/Zero-DCE_code/lowlight_test.py
index 8bf2bd3..bb4b17f 100644
--- a/Zero-DCE_code/lowlight_test.py
+++ b/Zero-DCE_code/lowlight_test.py
@@ -14,49 +14,165 @@ from PIL import Image
 import glob
 import time
+import argparse
+import torch
+import torchvision
+import torchvision.transforms as T
+from PIL import Image
+import cv2
+
 
-
-def lowlight(image_path):
-    os.environ['CUDA_VISIBLE_DEVICES']='0'
-    data_lowlight = Image.open(image_path)
-
-
-
-    data_lowlight = (np.asarray(data_lowlight)/255.0)
+def Color_Choice(color_space,data_lowlight):
+    #data_lowlight = Image.open(data_lowlight_path).convert(color)
+    #data_lowlight = cv2.imread(data_lowlight_path)
+    data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB)
+    if color_space == "RGB":
+        data_lowlight = (np.asarray(data_lowlight)/255.0)
+        n = [255,0,255,0,255,0]
+        back = cv2.COLOR_RGB2BGR
+    else:
+        if color_space == "HSV":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HSV)
+            n = [180,0,255,0,255,0]
+            back = cv2.COLOR_HSV2BGR
+        elif color_space == "HLS":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HLS)
+            n = [180,0,255,0,255,0]
+            back = cv2.COLOR_HLS2BGR
+        elif color_space == "YCbCr":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YCrCb)
+            n = [255,0,255,0,255,0]
+            back = cv2.COLOR_YCrCb2BGR
+        elif color_space == "YUV":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YUV)
+            n = [255,0,255,0,255,0]
+            back = cv2.COLOR_YUV2BGR
+        elif color_space == "LAB":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Lab)
+            n = [255,0,255-1,1,255-1,1]
+            back = cv2.COLOR_Lab2BGR
+        elif color_space == "LUV":
+            data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Luv)
+            n = [255,0,255,0,255,0]
+            back = cv2.COLOR_Luv2BGR
+        c1,c2,c3 = cv2.split(data_lowlight)
+        data_lowlight_1 = ((c1-n[1])/(n[0]))
+        data_lowlight_2 = ((c2-n[3])/(n[2]))
+        data_lowlight_3 = ((c3-n[5])/(n[4]))
+        data_lowlight = cv2.merge([data_lowlight_1,data_lowlight_2,data_lowlight_3])
+
+    data_lowlight = torch.from_numpy(data_lowlight).float()
+    data_lowlight = data_lowlight.permute(2,0,1)
+    return data_lowlight,n,back
 
-    data_lowlight = torch.from_numpy(data_lowlight).float()
-    data_lowlight = data_lowlight.permute(2,0,1)
-    data_lowlight = data_lowlight.cuda().unsqueeze(0)
 
-    DCE_net = model.enhance_net_nopool().cuda()
-    DCE_net.load_state_dict(torch.load('snapshots/Epoch99.pth'))
-    start = time.time()
-    _,enhanced_image,_ = DCE_net(data_lowlight)
+def lowlight(color_channel,lowlight_image):
+    os.environ['CUDA_VISIBLE_DEVICES']='0'
+
+    data_lowlight,con,inchan = Color_Choice(color_channel,lowlight_image)
+    data_lowlight = data_lowlight.cuda().unsqueeze(0)
+
+    if config.channel=="RGB":
+        DCE_net = model.enhance_net_nopool_3().cuda()
+    elif config.channel=="HSV":
+        DCE_net = model.enhance_net_nopool_1_3().cuda()
+    elif config.channel=="HLS":
+        DCE_net = model.enhance_net_nopool_1_2().cuda()
+    elif config.channel=="YCbCr" or config.channel=="YUV" or config.channel=="LAB" or config.channel=="LUV":
+        DCE_net = model.enhance_net_nopool_1_1().cuda()
+
+    DCE_net.load_state_dict(torch.load("snapshots/"+config.channel+".pth"))
+    #start = time.time()
+
+    _,enhanced_image,_ = DCE_net(data_lowlight)
+    data_lowlight = enhanced_image[0].permute(1,2,0).cpu().numpy()
+    temp1 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
+    temp2 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
+    temp3 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
+    temp1[:,:] = (data_lowlight[:,:,0]*con[0]+con[1]).astype(dtype="uint8")
+    temp2[:,:] = (data_lowlight[:,:,1]*con[2]+con[3]).astype(dtype="uint8")
+    temp3[:,:] = (data_lowlight[:,:,2]*con[4]+con[5]).astype(dtype="uint8")
 
-    end_time = (time.time() - start)
-    print(end_time)
-    image_path = image_path.replace('test_data','result')
-    result_path = image_path
-    if not os.path.exists(image_path.replace('/'+image_path.split("/")[-1],'')):
-        os.makedirs(image_path.replace('/'+image_path.split("/")[-1],''))
+    data_enhanced = cv2.cvtColor(cv2.merge([temp1,temp2,temp3]), inchan)
+
+    #end_time = (time.time() - start)
+    #print(end_time)
+    return data_enhanced
+    """
+    result_path = lowlight_images_path.replace('test_data','result')
+    if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
+        os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))
+
+    torchvision.utils.save_image(enhanced_image, result_path)
+    """
 
-    torchvision.utils.save_image(enhanced_image, result_path)
 
 if __name__ == '__main__':
 # test_images
-    with torch.no_grad():
-        filePath = 'data/test_data/'
-
-        file_list = os.listdir(filePath)
-
-        for file_name in file_list:
-            test_list = glob.glob(filePath+file_name+"/*")
-            for image in test_list:
-                # image = image
-                print(image)
-                lowlight(image)
-
-
-
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--lowlight_images_path", type=str, default="data/test_data/")
+    parser.add_argument("--mode", type=str, default="image")
+    parser.add_argument("--channel", type=str, default="RGB")
+    parser.add_argument("--save_images_path", type=str, default="data/result")
+    config = parser.parse_args()
+    with torch.no_grad():
+        file_path = config.lowlight_images_path
+        file_list = os.listdir(config.lowlight_images_path)
+        if config.mode == "image":
+            start1 = time.time()
+            for file_name in file_list:
+                test_list = glob.glob(file_path+file_name+"/*")
+                for image_path in test_list:
+                    print(image_path)
+                    start2 = time.time()
+                    data_lowlight = cv2.imread(image_path)
+                    data_enhanced = lowlight(config.channel,data_lowlight)
+                    result_path = image_path.replace('test_data','result')
+                    if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
+                        os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))
+                    cv2.imwrite(result_path,data_enhanced)
+                    end_time2 = (time.time() - start2)
+                    print("execution time of each frame: ", end_time2)
+            end_time1 = (time.time() - start1)
+            print("execution time of all images: ", end_time1)
+        elif config.mode == "video":
+            start1 = time.time()
+            for file_name in file_list:
+                test_list = glob.glob(file_path+file_name+"/*")
+                for video_path in test_list:
+                    print(video_path)
+                    start2 = time.time()
+                    cap = cv2.VideoCapture(video_path)
+                    fps = cap.get(cv2.CAP_PROP_FPS)
+                    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)   # float
+                    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
+                    result_path = video_path.replace('test_data','result')
+                    if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
+                        os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))
+                    vid_writer = cv2.VideoWriter(
+                        result_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
+                    )
+                    i = 0
+                    while True:
+                        ret_val, frame = cap.read()
+                        if ret_val:
+                            start3 = time.time()
+                            frame = np.array(frame)
+                            frame_enhanced = lowlight(config.channel,frame)
+                            vid_writer.write(frame_enhanced)
+                            end_time3 = (time.time() - start3)
+                            print("execution time of "+str(i+1)+"-th frame: ", end_time3)
+                            i = i+1
+                        else:
+                            break
+                    end_time2 = (time.time() - start2)
+                    print("original frames per second: ",fps)
+                    print("execution time of full video: ", end_time2)
+            end_time1 = (time.time() - start1)
+            print("execution time of all videos: ", end_time1)
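`lowlight()` now returns a BGR `uint8` image instead of saving to disk, so the caller writes it out with `cv2.imwrite` or a `cv2.VideoWriter`. The round-trip back to 8-bit is the inverse of `Color_Choice`'s normalization; a standalone sketch (the helper name `denormalize` is illustrative, not part of the patch):

```python
import cv2
import numpy as np

def denormalize(enhanced, n, back):
    """enhanced: HxWx3 float array in [0, 1] from the network;
    n: the 6-element scale/offset list returned by Color_Choice;
    back: the cv2 conversion flag back to BGR, e.g. cv2.COLOR_HSV2BGR."""
    c1 = (enhanced[:, :, 0] * n[0] + n[1]).astype("uint8")
    c2 = (enhanced[:, :, 1] * n[2] + n[3]).astype("uint8")
    c3 = (enhanced[:, :, 2] * n[4] + n[5]).astype("uint8")
    return cv2.cvtColor(cv2.merge([c1, c2, c3]), back)
```

Note that `astype("uint8")` wraps values that fall outside the 8-bit range; applying `np.clip(..., 0, 255)` before the cast would be safer if the network output strays outside [0, 1].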
diff --git a/Zero-DCE_code/lowlight_train.py b/Zero-DCE_code/lowlight_train.py
index 7743e8d..1d07273 100644
--- a/Zero-DCE_code/lowlight_train.py
+++ b/Zero-DCE_code/lowlight_train.py
@@ -12,7 +12,7 @@ import Myloss
 import numpy as np
 from torchvision import transforms
-
+import tensorflow as tf
 
 def weights_init(m):
     classname = m.__class__.__name__
@@ -27,15 +27,24 @@ def train(config):
-    os.environ['CUDA_VISIBLE_DEVICES']='0'
-
-    DCE_net = model.enhance_net_nopool().cuda()
-
+    if config.channel=="RGB":
+        DCE_net = model.enhance_net_nopool_3().cuda()
+    elif config.channel=="HSV":
+        DCE_net = model.enhance_net_nopool_1_3().cuda()
+    elif config.channel=="HLS":
+        DCE_net = model.enhance_net_nopool_1_2().cuda()
+    elif config.channel=="YCbCr" or config.channel=="YUV" or config.channel=="LAB" or config.channel=="LUV":
+        DCE_net = model.enhance_net_nopool_1_1().cuda()
+
     DCE_net.apply(weights_init)
+    """
     if config.load_pretrain == True:
         DCE_net.load_state_dict(torch.load(config.pretrain_dir))
-    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
+    """
+
+    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path, config.channel)
+    print(len(train_dataset))
 
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
@@ -49,29 +58,38 @@
 
     optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
-
+    data_ter = len(train_loader)
     DCE_net.train()
 
     for epoch in range(config.num_epochs):
+        train_summary_writer = tf.summary.create_file_writer("log/train_loss_"+config.channel)
+        Loss_TV = []
+        loss_spa = []
+        loss_col = []
+        loss_exp = []
+        loss_tot = []
         for iteration, img_lowlight in enumerate(train_loader):
-
             img_lowlight = img_lowlight.cuda()
 
             enhanced_image_1,enhanced_image,A = DCE_net(img_lowlight)
 
-            Loss_TV = 200*L_TV(A)
-
-            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
-
-            loss_col = 5*torch.mean(L_color(enhanced_image))
-
-            loss_exp = 10*torch.mean(L_exp(enhanced_image))
+            loss_TV = L_TV(A)
+            print("Loss_TV:",loss_TV.item())
+            loss_spa_i = torch.mean(L_spa(enhanced_image, img_lowlight))
+            print("loss_spa:",loss_spa_i.item())
+            loss_col_i = torch.mean(L_color(enhanced_image))
+            print("loss_col:",loss_col_i.item())
+            loss_exp_i = torch.mean(L_exp(enhanced_image))
+            print("loss_exp:",loss_exp_i.item())
 
             # best_loss
-            loss = Loss_TV + loss_spa + loss_col + loss_exp
+            loss = 200*loss_TV + loss_spa_i + 5*loss_col_i + 10*loss_exp_i
+            # store detached copies for the per-epoch TensorBoard averages, so each
+            # iteration's autograd graph can be freed right after backward()
+            Loss_TV.append(loss_TV.detach())
+            loss_spa.append(loss_spa_i.detach())
+            loss_col.append(loss_col_i.detach())
+            loss_exp.append(loss_exp_i.detach())
+            loss_tot.append(loss.detach())
             #
 
             optimizer.zero_grad()
             loss.backward()
@@ -82,7 +100,14 @@
                 print("Loss at iteration", iteration+1, ":", loss.item())
             if ((iteration+1) % config.snapshot_iter) == 0:
-                torch.save(DCE_net.state_dict(), config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
+                torch.save(DCE_net.state_dict(), config.snapshots_folder + config.channel + '.pth') #+ "Epoch" + str(epoch)
+        with train_summary_writer.as_default():
+            tf.summary.scalar('Illumination Smoothness Loss', sum(Loss_TV).item()/data_ter, step=epoch)
+            tf.summary.scalar('Spatial Loss', sum(loss_spa).item()/data_ter, step=epoch)
+            tf.summary.scalar('Color Loss', sum(loss_col).item()/data_ter, step=epoch)
+            tf.summary.scalar('Exposure Control Loss', sum(loss_exp).item()/data_ter, step=epoch)
+            tf.summary.scalar('Total Loss', sum(loss_tot).item()/data_ter, step=epoch)
+        train_summary_writer.close()
@@ -93,6 +118,7 @@
     # Input Parameters
     parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
+    parser.add_argument("--channel", type=str, default="RGB") # HSV, YCbCr, LAB, Luv, HLS, YUV
     parser.add_argument('--lr', type=float, default=0.0001)
     parser.add_argument('--weight_decay', type=float, default=0.0001)
     parser.add_argument('--grad_clip_norm', type=float, default=0.1)
@@ -113,12 +139,3 @@
     train(config)
 
-
-
-
-
-
-
-
-
-
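The scalars above are written with TensorFlow's `tf.summary` API purely for logging. If you would rather not pull in TensorFlow just for this, PyTorch's own `torch.utils.tensorboard` writer produces the same TensorBoard event files; a minimal sketch under that substitution (dummy values, not part of the patch):

```python
# Alternative logging sketch using PyTorch's built-in TensorBoard writer.
from torch.utils.tensorboard import SummaryWriter

num_epochs = 3
epoch_losses = [0.5, 0.4, 0.3]  # dummy per-epoch average losses for illustration

writer = SummaryWriter("log/train_loss_RGB")
for epoch in range(num_epochs):
    writer.add_scalar("Total Loss", epoch_losses[epoch], global_step=epoch)
writer.close()
```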
diff --git a/Zero-DCE_code/model.py b/Zero-DCE_code/model.py
index 3b710a5..ed69bed 100644
--- a/Zero-DCE_code/model.py
+++ b/Zero-DCE_code/model.py
@@ -5,10 +5,10 @@
 #import pytorch_colors as colors
 import numpy as np
 
-class enhance_net_nopool(nn.Module):
+class enhance_net_nopool_3(nn.Module):
 
     def __init__(self):
-        super(enhance_net_nopool, self).__init__()
+        super(enhance_net_nopool_3, self).__init__()
 
         self.relu = nn.ReLU(inplace=True)
 
@@ -57,3 +57,186 @@ def forward(self, x):
 
 
 
+
+
+
+
+# Variant that enhances only the first channel (Y of YCbCr/YUV, L of LAB/LUV);
+# the other two channels pass through unchanged.
+class enhance_net_nopool_1_1(nn.Module):
+
+    def __init__(self):
+        super(enhance_net_nopool_1_1, self).__init__()
+
+        self.relu = nn.ReLU(inplace=True)
+
+        number_f = 32
+        self.e_conv1 = nn.Conv2d(1,number_f,3,1,1,bias=True)
+        self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv7 = nn.Conv2d(number_f*2,8,3,1,1,bias=True)
+
+        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
+        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
+
+    def forward(self, X):
+        y = X[:,:1,:,:]
+        x1 = self.relu(self.e_conv1(y))
+        # p1 = self.maxpool(x1)
+        x2 = self.relu(self.e_conv2(x1))
+        # p2 = self.maxpool(x2)
+        x3 = self.relu(self.e_conv3(x2))
+        # p3 = self.maxpool(x3)
+        x4 = self.relu(self.e_conv4(x3))
+
+        x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
+        # x5 = self.upsample(x5)
+        x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
+
+        x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
+        r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 1, dim=1)
+
+        x = y + r1*(torch.pow(y,2)-y)
+        x = x + r2*(torch.pow(x,2)-x)
+        x = x + r3*(torch.pow(x,2)-x)
+        enhance_image_1 = x + r4*(torch.pow(x,2)-x)
+        enhance_image_2 = torch.cat([enhance_image_1,X[:,1:2,:,:],X[:,2:,:,:]],1)
+        x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
+        x = x + r6*(torch.pow(x,2)-x)
+        x = x + r7*(torch.pow(x,2)-x)
+        enhance_image = x + r8*(torch.pow(x,2)-x)
+        enhance_image_3 = torch.cat([enhance_image,X[:,1:2,:,:],X[:,2:,:,:]],1)
+        r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
+        return enhance_image_2,enhance_image_3,r
+
+# Variant that enhances only the second channel (L of HLS).
+class enhance_net_nopool_1_2(nn.Module):
+
+    def __init__(self):
+        super(enhance_net_nopool_1_2, self).__init__()
+
+        self.relu = nn.ReLU(inplace=True)
+
+        number_f = 32
+        self.e_conv1 = nn.Conv2d(1,number_f,3,1,1,bias=True)
+        self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv7 = nn.Conv2d(number_f*2,8,3,1,1,bias=True)
+
+        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
+        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
+
+    def forward(self, X):
+        y = X[:,1:2,:,:]
+        x1 = self.relu(self.e_conv1(y))
+        # p1 = self.maxpool(x1)
+        x2 = self.relu(self.e_conv2(x1))
+        # p2 = self.maxpool(x2)
+        x3 = self.relu(self.e_conv3(x2))
+        # p3 = self.maxpool(x3)
+        x4 = self.relu(self.e_conv4(x3))
+
+        x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
+        # x5 = self.upsample(x5)
+        x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
+
+        x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
+        r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 1, dim=1)
+
+        x = y + r1*(torch.pow(y,2)-y)
+        x = x + r2*(torch.pow(x,2)-x)
+        x = x + r3*(torch.pow(x,2)-x)
+        enhance_image_1 = x + r4*(torch.pow(x,2)-x)
+        enhance_image_2 = torch.cat([X[:,:1,:,:],enhance_image_1,X[:,2:,:,:]],1)
+        x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
+        x = x + r6*(torch.pow(x,2)-x)
+        x = x + r7*(torch.pow(x,2)-x)
+        enhance_image = x + r8*(torch.pow(x,2)-x)
+        enhance_image_3 = torch.cat([X[:,:1,:,:],enhance_image,X[:,2:,:,:]],1)
+        r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
+        return enhance_image_2,enhance_image_3,r
+
+# Variant that enhances only the third channel (V of HSV).
+class enhance_net_nopool_1_3(nn.Module):
+
+    def __init__(self):
+        super(enhance_net_nopool_1_3, self).__init__()
+
+        self.relu = nn.ReLU(inplace=True)
+
+        number_f = 32
+        self.e_conv1 = nn.Conv2d(1,number_f,3,1,1,bias=True)
+        self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
+        self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
+        self.e_conv7 = nn.Conv2d(number_f*2,8,3,1,1,bias=True)
+
+        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
+        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
+
+    def forward(self, X):
+        y = X[:,2:,:,:]
+        x1 = self.relu(self.e_conv1(y))
+        # p1 = self.maxpool(x1)
+        x2 = self.relu(self.e_conv2(x1))
+        # p2 = self.maxpool(x2)
+        x3 = self.relu(self.e_conv3(x2))
+        # p3 = self.maxpool(x3)
+        x4 = self.relu(self.e_conv4(x3))
+
+        x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
+        # x5 = self.upsample(x5)
+        x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
+
+        x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
+        r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 1, dim=1)
+
+        x = y + r1*(torch.pow(y,2)-y)
+        x = x + r2*(torch.pow(x,2)-x)
+        x = x + r3*(torch.pow(x,2)-x)
+        enhance_image_1 = x + r4*(torch.pow(x,2)-x)
+        enhance_image_2 = torch.cat([X[:,:1,:,:],X[:,1:2,:,:],enhance_image_1],1)
+        x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
+        x = x + r6*(torch.pow(x,2)-x)
+        x = x + r7*(torch.pow(x,2)-x)
+        enhance_image = x + r8*(torch.pow(x,2)-x)
+        enhance_image_3 = torch.cat([X[:,:1,:,:],X[:,1:2,:,:],enhance_image],1)
+        r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
+        return enhance_image_2,enhance_image_3,r
diff --git a/Zero-DCE_code/snapshots/HLS.pth b/Zero-DCE_code/snapshots/HLS.pth
new file mode 100644
index 0000000..25be598
Binary files /dev/null and b/Zero-DCE_code/snapshots/HLS.pth differ
diff --git a/Zero-DCE_code/snapshots/HSV.pth b/Zero-DCE_code/snapshots/HSV.pth
new file mode 100644
index 0000000..25664fc
Binary files /dev/null and b/Zero-DCE_code/snapshots/HSV.pth differ
diff --git a/Zero-DCE_code/snapshots/LAB.pth b/Zero-DCE_code/snapshots/LAB.pth
new file mode 100644
index 0000000..5cf6f38
Binary files /dev/null and b/Zero-DCE_code/snapshots/LAB.pth differ
diff --git a/Zero-DCE_code/snapshots/LUV.pth b/Zero-DCE_code/snapshots/LUV.pth
new file mode 100644
index 0000000..c07103b
Binary files /dev/null and b/Zero-DCE_code/snapshots/LUV.pth differ
diff --git a/Zero-DCE_code/snapshots/RGB.pth b/Zero-DCE_code/snapshots/RGB.pth
new file mode 100644
index 0000000..300ee83
Binary files /dev/null and b/Zero-DCE_code/snapshots/RGB.pth differ
diff --git a/Zero-DCE_code/snapshots/YCbCr.pth b/Zero-DCE_code/snapshots/YCbCr.pth
new file mode 100644
index 0000000..f5291b3
Binary files /dev/null and b/Zero-DCE_code/snapshots/YCbCr.pth differ
diff --git a/Zero-DCE_code/snapshots/YUV.pth b/Zero-DCE_code/snapshots/YUV.pth
new file mode 100644
index 0000000..f323c34
Binary files /dev/null and b/Zero-DCE_code/snapshots/YUV.pth differ
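All three single-channel heads share the curve-estimation scheme of the original Zero-DCE: the network predicts eight per-pixel curve maps `r1..r8` and repeatedly applies the quadratic curve x ← x + r·(x² − x) to the selected channel, while the other two channels pass through unchanged. A condensed sketch of that loop (ignoring the intermediate result the forward also returns):

```python
import torch

def apply_curves(y, r_maps):
    """y: (B, 1, H, W) channel in [0, 1]; r_maps: (B, 8, H, W) curve parameters in [-1, 1]."""
    x = y
    for i in range(r_maps.shape[1]):
        r = r_maps[:, i:i + 1, :, :]
        x = x + r * (torch.pow(x, 2) - x)  # LE(x) = x + r * (x^2 - x)
    return x

y = torch.rand(1, 1, 64, 64)          # a normalized luminance channel
r = torch.rand(1, 8, 64, 64) * 2 - 1  # tanh output lies in [-1, 1]
enhanced = apply_curves(y, r)
print(enhanced.shape)                 # torch.Size([1, 1, 64, 64])
```

Since r is bounded by the tanh, each iteration is a gentle monotonic adjustment of the channel, and stacking eight of them gives the network enough flexibility to brighten dark regions without saturating bright ones.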