From 0c2abb121c1ff2555fc0eb38ba5b8343d816802a Mon Sep 17 00:00:00 2001
From: John Pope
Date: Thu, 10 Oct 2024 16:30:00 +1100
Subject: [PATCH] fix for resuming

---
 config.yaml |  2 +-
 train.py    | 16 ++++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/config.yaml b/config.yaml
index 984f4f4..18c9167 100644
--- a/config.yaml
+++ b/config.yaml
@@ -11,7 +11,7 @@ profiling:
   profile_step: 10
 
 training:
-  load_checkpoint: False # Set this to true when you want to load from a checkpoint
+  load_checkpoint: True # Set this to true when you want to load from a checkpoint
   checkpoint_path: './checkpoints/checkpoint.pth'
   use_eye_loss: False
   use_subsampling: False # saves ram? https://github.com/johndpope/MegaPortrait-hack/issues/41
diff --git a/train.py b/train.py
index 800a86d..27cf595 100644
--- a/train.py
+++ b/train.py
@@ -332,17 +332,22 @@ def save_checkpoint(self, epoch, is_final=False):
 
     def load_checkpoint(self, checkpoint_path):
         try:
-            checkpoint = self.accelerator.load(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path, map_location=self.accelerator.device)
 
-            self.model.load_state_dict(checkpoint['model_state_dict'])
-            self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+            # Unwrap the models before loading state dict
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+            unwrapped_discriminator = self.accelerator.unwrap_model(self.discriminator)
+
+            unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
+            unwrapped_discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
             self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
             self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
             self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict'])
             self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict'])
 
             if self.ema and 'ema_state_dict' in checkpoint:
-                self.ema.load_state_dict(checkpoint['ema_state_dict'])
+                unwrapped_ema = self.accelerator.unwrap_model(self.ema)
+                unwrapped_ema.load_state_dict(checkpoint['ema_state_dict'])
 
             start_epoch = checkpoint['epoch'] + 1
             print(f"Loaded checkpoint from epoch {start_epoch - 1}")
@@ -398,6 +403,9 @@ def main():
         collate_fn=gpu_padded_collate
     )
 
+    print("using float32 for onnx training....")
+    torch.set_default_dtype(torch.float32)
+
     trainer = IMFTrainer(config, model, discriminator, dataloader, accelerator)
 
     # Check if a checkpoint path is provided in the config
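
For context, here is a minimal, self-contained sketch (not part of the patch) of the save/load pattern the fix relies on: write the checkpoint from the *unwrapped* modules, then on resume read it back with `torch.load` plus `accelerator.unwrap_model`, so state-dict keys are not prefixed by the distributed wrappers that `accelerator.prepare` adds. The toy model, file name, and everything beyond the checkpoint keys visible in the diff (`model_state_dict`, `optimizer_g_state_dict`, `epoch`) are illustrative assumptions, not taken from train.py.

```python
# Sketch of the unwrap-before-save / unwrap-before-load pattern (illustrative only).
import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(4, 2)                      # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

# Save from the main process using the unwrapped weights, matching the keys
# that load_checkpoint in the patch expects.
if accelerator.is_main_process:
    torch.save(
        {
            "model_state_dict": accelerator.unwrap_model(model).state_dict(),
            "optimizer_g_state_dict": optimizer.state_dict(),
            "epoch": 0,
        },
        "checkpoint.pth",
    )
accelerator.wait_for_everyone()

# Resume: torch.load onto the accelerator's device, then load into the
# unwrapped module, as in the patched load_checkpoint.
checkpoint = torch.load("checkpoint.pth", map_location=accelerator.device)
accelerator.unwrap_model(model).load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_g_state_dict"])
start_epoch = checkpoint["epoch"] + 1
```

The key point of the fix is symmetry: if the state dicts were saved from wrapped (e.g. DDP) modules but loaded into unwrapped ones, or vice versa, the key prefixes would not match and resuming would fail.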