Skip to content

Commit

Permalink
early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
33tm committed Jul 28, 2024
1 parent d397e5d commit 0e88a5e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def forward(self, input):
if torch.cuda.is_available():
print(f"CUDA via {torch.cuda.get_device_name()}")

epochs_without_improvement = 0
best_validation_loss = float('inf')

for epoch in range(24):
model.train()
training_loss = 0
Expand All @@ -107,8 +110,17 @@ def forward(self, input):
avg_validation_loss = validation_loss / len(validation)
scheduler.step(avg_validation_loss)

print(f"[{getElapsed()}] Epoch {epoch + 1:02} of 24, Training Loss: {avg_training_loss:.4f}, Validation Loss: {avg_validation_loss:.4f}")
if avg_validation_loss < best_validation_loss:
best_validation_loss = avg_validation_loss
epochs_without_improvement = 0
torch.save(model.state_dict(), "out/model.pt")
else:
epochs_without_improvement += 1

if epochs_without_improvement >= 5:
print(f"Early stop at epoch {epoch + 1:02}")
break

torch.save(model.state_dict(), "out/model.pt")
print(f"[{getElapsed()}] Epoch {epoch + 1:02}, Training Loss: {avg_training_loss:.4f}, Validation Loss: {avg_validation_loss:.4f}")

print(f"\nFinished in {getElapsed()}")

0 comments on commit 0e88a5e

Please sign in to comment.