diff --git a/src/anomalib/models/patchcore/lightning_model.py b/src/anomalib/models/patchcore/lightning_model.py
index 2464def2a0..8018ef99d7 100644
--- a/src/anomalib/models/patchcore/lightning_model.py
+++ b/src/anomalib/models/patchcore/lightning_model.py
@@ -120,7 +120,10 @@ def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> None
         if not self.trainer.sanity_checking:
             # Initialize the embeddings tensor with the estimated number of batches
             self.embeddings = torch.zeros(
-                (embedding.shape[0] * (self.trainer.estimated_stepping_batches), *embedding.shape[1:]),
+                (
+                    embedding.shape[0] * self.trainer.estimated_stepping_batches * self.trainer.max_epochs,
+                    *embedding.shape[1:],
+                ),
                 device=self.device,
                 dtype=embedding.dtype,
             )