From 9b7cfb785809bef5ed84dad26c8ae925f3721c91 Mon Sep 17 00:00:00 2001 From: Tal Muskal Date: Mon, 26 Sep 2022 00:08:46 +0300 Subject: [PATCH] input image support through --init-img argument --- scripts/stable_txt2img.py | 60 ++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/scripts/stable_txt2img.py b/scripts/stable_txt2img.py index 1f9cc4a..5a09b6e 100644 --- a/scripts/stable_txt2img.py +++ b/scripts/stable_txt2img.py @@ -15,6 +15,8 @@ from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +import PIL +from einops import repeat def chunk(it, size): @@ -42,6 +44,21 @@ def load_model_from_config(config, ckpt, verbose=False): return model +def load_img(path,W = None,H= None): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + if(W): + w = W + if(H): + h = H + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. 
+ def main(): parser = argparse.ArgumentParser() @@ -174,7 +190,17 @@ def main(): choices=["full", "autocast"], default="autocast" ) - parser.add_argument( + "--init-img", + type=str, + help="path to the input image" + ) + parser.add_argument( + "--strength", + type=float, + default=0.75, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image" + ) parser.add_argument( "--embedding_path", @@ -188,7 +214,8 @@ def main(): opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" opt.ckpt = "models/ldm/text2img-large/model.ckpt" opt.outdir = "outputs/txt2img-samples-laion400m" - + if(opt.plms and opt.init_img): + raise Exception("input image is incompatible with PLMS") seed_everything(opt.seed) config = OmegaConf.load(f"{opt.config}") @@ -225,7 +252,14 @@ def main(): grid_count = len(os.listdir(outpath)) - 1 start_code = None - if opt.fixed_code: + t_enc = None + if(opt.init_img): + init_image = load_img(opt.init_img,opt.W,opt.H).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + t_enc = int(opt.strength * opt.ddim_steps) + elif opt.fixed_code: start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision=="autocast" else nullcontext @@ -242,16 +276,22 @@ def main(): if isinstance(prompts, tuple): prompts = list(prompts) c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) + if(opt.init_img != None): + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + # decode it + samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + 
unconditional_conditioning=uc,) + else: + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)