From 1af4061d7de712409b22ac1acc9354f24c3a347a Mon Sep 17 00:00:00 2001 From: Abb <62794980+Abbsalehi@users.noreply.github.com> Date: Mon, 9 Dec 2024 22:16:01 -0600 Subject: [PATCH 1/2] Update lama_inpaint.py --- lama_inpaint.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/lama_inpaint.py b/lama_inpaint.py index 1807e84..c06ad9c 100644 --- a/lama_inpaint.py +++ b/lama_inpaint.py @@ -58,13 +58,15 @@ def inpaint_img_with_lama( model = load_checkpoint( train_config, checkpoint_path, strict=False, map_location='cpu') model.freeze() - if not predict_config.get('refine', False): - model.to(device) + # if not predict_config.get('refine', False): + # model.to(device) + model.to(device) batch = {} batch['image'] = img.permute(2, 0, 1).unsqueeze(0) batch['mask'] = mask[None, None] unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]] + batch['unpad_to_size']= torch.tensor(unpad_to_size).to(device) batch['image'] = pad_tensor_to_modulo(batch['image'], mod) batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod) batch = move_to_device(batch, device) @@ -73,7 +75,21 @@ def inpaint_img_with_lama( batch = model(batch) cur_res = batch[predict_config.out_key][0].permute(1, 2, 0) cur_res = cur_res.detach().cpu().numpy() - + # Feature Refinement to Improve High Resolution Image Inpainting + if predict_config.get('refine', False): + # assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement" + # image unpadding is taken care of in the refiner + # is same size as the input image + cur_res = refine_predict(batch, model, **predict_config.refiner) + cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() + else: + with torch.no_grad(): + batch = move_to_device(batch, device) + batch['mask'] = (batch['mask'] > 0) * 1 + batch = model(batch) + cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy() + unpad_to_size = batch.get('unpad_to_size', None) + if unpad_to_size is not None: orig_height, orig_width = unpad_to_size cur_res = cur_res[:orig_height, :orig_width] @@ -197,4 +213,4 @@ def setup_args(parser): img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}" img_inpainted = inpaint_img_with_lama( img, mask, args.lama_config, args.lama_ckpt, device=device) - save_array_to_img(img_inpainted, img_inpainted_p) \ No newline at end of file + save_array_to_img(img_inpainted, img_inpainted_p) From 706139125d5bfb36f5c36f927c081749ab63f3cd Mon Sep 17 00:00:00 2001 From: Abb <62794980+Abbsalehi@users.noreply.github.com> Date: Mon, 9 Dec 2024 22:17:09 -0600 Subject: [PATCH 2/2] Update refinement.py --- lama/saicinpainting/evaluation/refinement.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lama/saicinpainting/evaluation/refinement.py b/lama/saicinpainting/evaluation/refinement.py index d9d3cba..a4a9ca1 100644 --- a/lama/saicinpainting/evaluation/refinement.py +++ b/lama/saicinpainting/evaluation/refinement.py @@ -196,7 +196,8 @@ def _get_image_mask_pyramid(batch : dict, min_side : int, max_scales : int, px_b assert batch['image'].shape[0] == 1, "refiner works on only batches of size 1!" h, w = batch['unpad_to_size'] - h, w = h[0].item(), w[0].item() + # h, w = h[0].item(), w[0].item() + h, w = h.item(), w.item() image = batch['image'][...,:h,:w] mask = batch['mask'][...,:h,:w]