diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 78b90f8c3..ab5399842 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -504,7 +504,8 @@ def __call__( uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: @@ -567,9 +568,11 @@ def __call__( text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device)