diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index faa7a90d92de..bd95e61b769c 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -191,13 +191,7 @@ def _train(self, epoch: int):
             # DPO Loss
             loss = losses.mean()
 
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-            if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.actor_scheduler.step()
-
             # sync
             loss_mean = all_reduce_mean(tensor=loss)
             chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
@@ -208,10 +202,20 @@ def _train(self, epoch: int):
             self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
             self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
 
-            if i % self.accumulation_steps == self.accumulation_steps - 1:
-                self.num_train_step += 1
+            if (i + 1) % self.accumulation_steps == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.actor_scheduler.step()
+
+                step_bar.set_postfix(
+                    {
+                        "train/loss": self.accumulative_meter.get("loss"),
+                        "train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
+                        "train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
+                        "train/accuracy": self.accumulative_meter.get("accuracy"),
+                    }
+                )
                 step_bar.update()
-                # logging
                 if self.writer and is_rank_0():
                     self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                     self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
@@ -233,25 +237,26 @@ def _train(self, epoch: int):
                         self.accumulative_meter.get("accuracy"),
                         self.num_train_step,
                     )
+                self.num_train_step += 1
                 self.accumulative_meter.reset()
 
-            if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
-                # save checkpoint
-                self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
-                save_checkpoint(
-                    save_dir=self.save_dir,
-                    booster=self.booster,
-                    model=self.model,
-                    optimizer=self.optimizer,
-                    lr_scheduler=self.actor_scheduler,
-                    epoch=epoch,
-                    step=i + 1,
-                    batch_size=batch_size,
-                    coordinator=self.coordinator,
-                )
-                self.coordinator.print_on_master(
-                    f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
-                )
+                if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
+                    # save checkpoint
+                    self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.actor_scheduler,
+                        epoch=epoch,
+                        step=self.num_train_step,
+                        batch_size=batch_size,
+                        coordinator=self.coordinator,
+                    )
+                    self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
+                    )
 
         step_bar.close()
@@ -356,7 +361,8 @@ def _eval(self, epoch: int):
         for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
             msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
         self.coordinator.print_on_master(msg)
-        os.makedirs(self.save_dir, exist_ok=True)
-        with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
-            f.write(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
         step_bar.close()