diff --git a/train.py b/train.py index de5756e..22c324a 100644 --- a/train.py +++ b/train.py @@ -470,6 +470,10 @@ def main(): logger.info(f"Validation loss after short context recovery: {val_loss:.4f}") wandb.log({"short_context_val_loss": val_loss}) + # Save the final model + accelerator.save_state("final_model.pt") + wandb.save("final_model.pt") + # Finish logging and close the Weights & Biases run wandb.finish()