Merge pull request #35 from kookmin-sw/feat/flask_server (Feat/flask server)

Showing 146 changed files with 9,310 additions and 18 deletions.
New file: ImagesDataset (+39 lines)

```python
from torch.utils.data import Dataset
from PIL import Image
import PIL
from utils import data_utils
import torchvision.transforms as transforms
import os


class ImagesDataset(Dataset):
    """Yields (high-res tensor, 256x256 low-res tensor, file stem) per image."""

    def __init__(self, opts, image_path=None):
        # Accept a whole directory (opts.input_dir), a single path, or a list of paths.
        if not image_path:
            image_root = opts.input_dir
            self.image_paths = sorted(data_utils.make_dataset(image_root))
        elif isinstance(image_path, str):
            self.image_paths = [image_path]
        elif isinstance(image_path, list):
            self.image_paths = image_path

        # ToTensor maps pixels to [0, 1]; Normalize then maps them to [-1, 1].
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        self.opts = opts

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        im_path = self.image_paths[index]
        im_H = Image.open(im_path).convert('RGB')
        # Low-res copy, downsampled to 256x256 with Lanczos resampling.
        im_L = im_H.resize((256, 256), PIL.Image.LANCZOS)
        im_name = os.path.splitext(os.path.basename(im_path))[0]
        if self.image_transform:
            im_H = self.image_transform(im_H)
            im_L = self.image_transform(im_L)

        return im_H, im_L, im_name
```
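For orientation, a minimal usage sketch follows. It is hypothetical: `opts` is built ad hoc (the real one comes from the project's option parser), and `input/images` is a stand-in directory.

```python
# Hypothetical driver for ImagesDataset; only opts.input_dir is required here.
from types import SimpleNamespace
from torch.utils.data import DataLoader

opts = SimpleNamespace(input_dir='input/images')  # stand-in path
dataset = ImagesDataset(opts)

# batch_size=1: im_H keeps each image's native resolution, so high-res
# tensors may not share a common shape across files.
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for im_H, im_L, im_name in loader:
    print(im_name[0], tuple(im_H.shape), tuple(im_L.shape))
```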
New file: AlignLossBuilder (+55 lines)

```python
import torch
from losses.style.style_loss import StyleLoss


class AlignLossBuilder(torch.nn.Module):
    """Cross-entropy and masked style losses used during the alignment step."""

    def __init__(self, opt, no_face=False):
        super(AlignLossBuilder, self).__init__()

        self.opt = opt
        self.parsed_loss = [[opt.l2_lambda, 'l2'], [opt.percept_lambda, 'percep']]
        use_gpu = opt.device == 'cuda'  # kept from the original; unused in this builder

        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.style = StyleLoss(distance="l2", VGG16_ACTIVATIONS_LIST=[3, 8, 15, 22], normalize=False).to(opt.device)
        self.style.eval()

        # One-hot class weights over the 16 segmentation classes:
        # index 0 selects the background class, index 10 the hair class.
        tmp = torch.zeros(16).to(opt.device)
        tmp[0] = 1

        tmp_hair = torch.zeros(16).to(opt.device)
        tmp_hair[10] = 1

        weight_wo_background = 1 - tmp
        if no_face:
            weight_wo_background[1] = 0  # also drop class 1 (presumably the face/skin class)
        self.cross_entropy_wo_background = torch.nn.CrossEntropyLoss(weight=weight_wo_background)
        self.cross_entropy_only_background = torch.nn.CrossEntropyLoss(weight=tmp)
        self.cross_entropy_only_hair = torch.nn.CrossEntropyLoss(weight=tmp_hair)

    def cross_entropy_loss(self, down_seg, target_mask):
        return self.opt.ce_lambda * self.cross_entropy(down_seg, target_mask)

    def style_loss(self, im1, im2, mask1, mask2):
        return self.opt.style_lambda * self.style(im1 * mask1, im2 * mask2, mask1=mask1, mask2=mask2)

    def cross_entropy_loss_wo_background(self, down_seg, target_mask):
        return self.opt.ce_lambda * self.cross_entropy_wo_background(down_seg, target_mask)

    def cross_entropy_loss_only_background(self, down_seg, target_mask):
        return self.opt.ce_lambda * self.cross_entropy_only_background(down_seg, target_mask)

    def cross_entropy_loss_only_hair(self, down_seg, target_mask):
        return self.opt.ce_lambda * self.cross_entropy_only_hair(down_seg, target_mask)
```
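A hedged sketch of calling the cross-entropy terms: the `opt` fields and tensor shapes below are assumptions (per-pixel logits over the 16 classes, integer class-index masks), not values taken from the repository.

```python
# Hypothetical: build the loss with a minimal opt namespace on CPU.
from types import SimpleNamespace
import torch

opt = SimpleNamespace(device='cpu', l2_lambda=1.0, percept_lambda=1.0,
                      ce_lambda=1.0, style_lambda=1.0)
builder = AlignLossBuilder(opt)

down_seg = torch.randn(1, 16, 128, 128)            # per-pixel class logits
target_mask = torch.randint(0, 16, (1, 128, 128))  # per-pixel class indices
loss = builder.cross_entropy_loss_only_hair(down_seg, target_mask)
```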
New file: BlendLossBuilder (+62 lines)

```python
import torch
from losses import masked_lpips


class BlendLossBuilder(torch.nn.Module):
    """Masked LPIPS losses for blending: match the face region of one
    reference image and the hair region of another."""

    def __init__(self, opt):
        super(BlendLossBuilder, self).__init__()

        self.opt = opt
        self.parsed_loss = [[1.0, 'face'], [1.0, 'hair']]
        use_gpu = opt.device == 'cuda'

        self.face_percept = masked_lpips.PerceptualLoss(
            model="net-lin", net="vgg", vgg_blocks=['1', '2', '3'], use_gpu=use_gpu
        )
        self.face_percept.eval()

        self.hair_percept = masked_lpips.PerceptualLoss(
            model="net-lin", net="vgg", vgg_blocks=['1', '2', '3'], use_gpu=use_gpu
        )
        self.hair_percept.eval()

    def _loss_face_percept(self, gen_im, ref_im, mask, **kwargs):
        return self.face_percept(gen_im, ref_im, mask=mask)

    def _loss_hair_percept(self, gen_im, ref_im, mask, **kwargs):
        return self.hair_percept(gen_im, ref_im, mask=mask)

    def forward(self, gen_im, im_1, im_3, mask_face, mask_hair):
        # Compare the generated image against im_1 inside the face mask
        # and against im_3 inside the hair mask.
        loss = 0
        loss_fun_dict = {
            'face': self._loss_face_percept,
            'hair': self._loss_hair_percept,
        }
        losses = {}
        for weight, loss_type in self.parsed_loss:
            if loss_type == 'face':
                var_dict = {'gen_im': gen_im, 'ref_im': im_1, 'mask': mask_face}
            elif loss_type == 'hair':
                var_dict = {'gen_im': gen_im, 'ref_im': im_3, 'mask': mask_hair}
            tmp_loss = loss_fun_dict[loss_type](**var_dict)
            losses[loss_type] = tmp_loss
            loss += weight * tmp_loss
        return loss, losses
```
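A hedged sketch of the forward call with stand-in tensors. It assumes `masked_lpips.PerceptualLoss` accepts masks broadcastable to the image shape, which this commit does not show.

```python
# Hypothetical: blend loss between a generated image and two references.
from types import SimpleNamespace
import torch

opt = SimpleNamespace(device='cpu')
builder = BlendLossBuilder(opt)

gen_im = torch.rand(1, 3, 256, 256)
im_1 = torch.rand(1, 3, 256, 256)   # face reference
im_3 = torch.rand(1, 3, 256, 256)   # hair reference
mask_face = torch.zeros(1, 1, 256, 256)
mask_face[..., 128:, :] = 1         # bottom half as a toy face region
mask_hair = 1 - mask_face           # complementary toy hair region
total, per_term = builder(gen_im, im_1, im_3, mask_face, mask_hair)
```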
New file: EmbeddingLossBuilder (+72 lines)

```python
import torch
from losses import lpips
from losses.style.style_loss import StyleLoss


class EmbeddingLossBuilder(torch.nn.Module):
    """L2 and LPIPS reconstruction losses for embedding, plus an optional
    masked style loss on the hair region."""

    def __init__(self, opt):
        super(EmbeddingLossBuilder, self).__init__()

        self.opt = opt
        self.parsed_loss = [[opt.l2_lambda, 'l2'],
                            [opt.percept_lambda, 'percep'],
                            [opt.sp_hair_lambda, 'sp_hair']]
        self.l2 = torch.nn.MSELoss()
        use_gpu = opt.device == 'cuda'
        self.percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=use_gpu)
        self.percept.eval()
        # self.percept = VGGLoss()

        # Style loss over the masked hair region.
        self.style = StyleLoss(distance="l2", VGG16_ACTIVATIONS_LIST=[3, 8, 15, 22], normalize=False).to(opt.device)
        self.style.eval()

    def _loss_l2(self, gen_im, ref_im, **kwargs):
        return self.l2(gen_im, ref_im)

    def _loss_lpips(self, gen_im, ref_im, **kwargs):
        return self.percept(gen_im, ref_im).sum()

    def _loss_sp_hair(self, gen_im, ref_im, sp_mask):
        return self.style(gen_im * sp_mask, ref_im * sp_mask, mask1=sp_mask, mask2=sp_mask)

    def forward(self, ref_im_H, ref_im_L, gen_im_H, gen_im_L, sp_mask=None):
        # L2 is computed on the high-res pair; LPIPS and the optional hair
        # style term are computed on the low-res pair.
        loss = 0
        loss_fun_dict = {
            'l2': self._loss_l2,
            'percep': self._loss_lpips,
            'sp_hair': self._loss_sp_hair,
        }
        losses = {}
        for weight, loss_type in self.parsed_loss:
            if loss_type == 'l2':
                var_dict = {'gen_im': gen_im_H, 'ref_im': ref_im_H}
            elif loss_type == 'percep':
                var_dict = {'gen_im': gen_im_L, 'ref_im': ref_im_L}
            elif loss_type == 'sp_hair':
                # Skip when the term is disabled or no hair mask is given.
                if weight == 0 or sp_mask is None:
                    continue
                var_dict = {'gen_im': gen_im_L, 'ref_im': ref_im_L, 'sp_mask': sp_mask}
            tmp_loss = loss_fun_dict[loss_type](**var_dict)
            losses[loss_type] = tmp_loss
            loss += weight * tmp_loss
        return loss, losses
```
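Finally, a hedged sketch of the embedding loss in use. The lambda values and tensor sizes are illustrative assumptions (a 1024x1024 high-res pair and a 256x256 low-res pair, matching the dataset above); the `sp_hair` term is skipped here because its weight is zero and no mask is passed.

```python
# Hypothetical: embedding loss over high-res and low-res image pairs.
from types import SimpleNamespace
import torch

opt = SimpleNamespace(device='cpu', l2_lambda=1.0, percept_lambda=0.8,
                      sp_hair_lambda=0.0)
builder = EmbeddingLossBuilder(opt)

ref_H = torch.rand(1, 3, 1024, 1024)
gen_H = torch.rand(1, 3, 1024, 1024)
ref_L = torch.rand(1, 3, 256, 256)
gen_L = torch.rand(1, 3, 256, 256)
total, per_term = builder(ref_H, ref_L, gen_H, gen_L)  # sp_hair term skipped
```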