Skip to content

Commit

Permalink
Merge branch 'zyddnys:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dmMaze authored Apr 20, 2022
2 parents b76d7aa + c24ef01 commit 9b828fc
Show file tree
Hide file tree
Showing 8 changed files with 824 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ GPU server is not cheap, please consider to donate to us.
## Online Demo

Official Demo (by zyddnys): <https://touhou.ai/imgtrans/>\
Mirror Demo (by Eidenz): <https://manga.eidenz.com/>\
Mirror Demo (by Eidenz): <https://manga.eidenz.moe/>\
Browser Userscript (by QiroNT): <https://greasyfork.org/scripts/437569>

- Note this may not work sometimes due to stupid google gcp kept restarting my instance.
Expand Down
21 changes: 15 additions & 6 deletions inpainting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,31 @@
import cv2
import numpy as np
from .inpainting_aot import AOTGenerator
from .inpainting_lama import get_generator as get_lama_generator
from utils import resize_keep_aspect

DEFAULT_MODEL = None
INPAINTING_MODEL = None

def load_model(cuda: bool, model_name: str = 'default') :
global DEFAULT_MODEL
if model_name not in ['default'] :
global INPAINTING_MODEL
if model_name not in ['default', 'lama'] :
raise Exception
if model_name == 'default' and DEFAULT_MODEL is None :
if model_name == 'default' and INPAINTING_MODEL is None :
model = AOTGenerator()
sd = torch.load('inpainting.ckpt', map_location = 'cpu')
model.load_state_dict(sd['model'] if 'model' in sd else sd)
model.eval()
if cuda :
model = model.cuda()
DEFAULT_MODEL = model
INPAINTING_MODEL = model
if model_name == 'lama' and INPAINTING_MODEL is None :
model = get_lama_generator()
sd = torch.load('inpainting_lama.ckpt', map_location = 'cpu')
model.load_state_dict(sd['model'] if 'model' in sd else sd)
model.eval()
if cuda :
model = model.cuda()
INPAINTING_MODEL = model

async def dispatch(use_inpainting: bool, use_poisson_blending: bool, cuda: bool, img: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, model_name: str = 'default', verbose: bool = False) -> np.ndarray :
img_original = np.copy(img)
Expand Down Expand Up @@ -58,7 +67,7 @@ async def dispatch(use_inpainting: bool, use_poisson_blending: bool, cuda: bool,
mask_torch = mask_torch.cuda()
with torch.no_grad() :
img_torch *= (1 - mask_torch)
img_inpainted_torch = DEFAULT_MODEL(img_torch, mask_torch)
img_inpainted_torch = INPAINTING_MODEL(img_torch, mask_torch)
img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
if new_h != height or new_w != width :
img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR)
Expand Down
Loading

0 comments on commit 9b828fc

Please sign in to comment.