refactor: Include max_epochs in pre allocation
lorenzomammana committed Oct 18, 2024
1 parent a5d5b72 · commit 4d10383
Showing 1 changed file with 4 additions and 1 deletion: src/anomalib/models/patchcore/lightning_model.py
@@ -120,7 +120,10 @@ def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> None
         if not self.trainer.sanity_checking:
             # Initialize the embeddings tensor with the estimated number of batches
             self.embeddings = torch.zeros(
-                (embedding.shape[0] * (self.trainer.estimated_stepping_batches), *embedding.shape[1:]),
+                (
+                    embedding.shape[0] * self.trainer.estimated_stepping_batches * self.trainer.max_epochs,
+                    *embedding.shape[1:],
+                ),
                 device=self.device,
                 dtype=embedding.dtype,
             )
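
To make the intent of the diff concrete, here is a minimal, self-contained sketch of the pre-allocation pattern, not the repository's actual code: the PreallocatedCollector class, the num_batches parameter, and the _filled write offset are illustrative names. num_batches stands in for the estimated_stepping_batches * max_epochs product the commit uses to size the buffer.

import torch


class PreallocatedCollector:
    """Collects per-batch embeddings into a single pre-allocated tensor.

    Sketch only: `num_batches` stands in for
    `trainer.estimated_stepping_batches * trainer.max_epochs`, and the
    `_filled` write offset is hypothetical bookkeeping.
    """

    def __init__(self) -> None:
        self.embeddings: torch.Tensor | None = None
        self._filled = 0  # rows written so far

    def collect(self, embedding: torch.Tensor, num_batches: int) -> None:
        if self.embeddings is None:
            # Allocate once, sized for every batch of every epoch, rather
            # than appending to a Python list and concatenating at the end.
            self.embeddings = torch.zeros(
                (embedding.shape[0] * num_batches, *embedding.shape[1:]),
                device=embedding.device,
                dtype=embedding.dtype,
            )
        rows = embedding.shape[0]
        self.embeddings[self._filled : self._filled + rows] = embedding
        self._filled += rows


# Usage: ten batches of 4 x 384-dim patch embeddings.
collector = PreallocatedCollector()
for _ in range(10):
    collector.collect(torch.randn(4, 384), num_batches=10)
assert collector.embeddings is not None and collector.embeddings.shape == (40, 384)

The design trade-off: writing each batch into a buffer allocated up front avoids the peak-memory spike of holding a long list of per-batch tensors and torch.cat-ing them after training, at the cost of needing a batch-count estimate before the first step.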
