diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 95b13bc5..7df78761 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -182,11 +182,9 @@ def set_rng_generator(self, rng_generator: torch.Generator): def forward(self, batch): latents, text_embeds, text_pooled_embeds, attention_mask, encoder_attention_mask = None, None, None, None, None - if 'attention_mask' in batch: + if 'attention_mask' in batch and self.mask_pad_tokens: attention_mask = batch['attention_mask'] # mask for text encoders - # text mask for U-Net - if self.mask_pad_tokens: - encoder_attention_mask = _create_unet_attention_mask(attention_mask) + encoder_attention_mask = _create_unet_attention_mask(attention_mask) # text mask for U-Net # Use latents if specified and available. When specified, they might not exist during eval if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: