diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py
index 63c903a51940..2e63fc5c8971 100755
--- a/applications/ColossalChat/coati/trainer/base.py
+++ b/applications/ColossalChat/coati/trainer/base.py
@@ -17,6 +17,7 @@
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,6 +39,7 @@ def __init__(
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index d37676ada3e0..ebdfd502491f 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -6,14 +6,16 @@
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ def __init__(
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ def _before_fit(
     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
             )
-            loss = outputs.loss
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-
-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
                 self.optimizer.step()
                 self.optimizer.zero_grad()
-                self.scheduler.step()
-
-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                    self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
-
-            # Save checkpoint
-            if (
-                self.save_dir is not None
-                and self.save_interval is not None
-                and (self.num_train_step + 1) % self.save_interval == 0
-            ):
-                save_checkpoint(
-                    save_dir=self.save_dir,
-                    booster=self.booster,
-                    model=self.model,
-                    optimizer=self.optimizer,
-                    lr_scheduler=self.scheduler,
-                    epoch=epoch,
-                    step=self.num_train_step + 1,
-                    batch_size=batch_size,
-                    coordinator=self.coordinator,
-                )
-                self.coordinator.print_on_master(
-                    f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                )
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss
+
+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                        self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                # Save checkpoint
+                if (
+                    self.save_dir is not None
+                    and self.save_interval is not None
+                    and (self.num_train_step + 1) % self.save_interval == 0
+                ):
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.scheduler,
+                        epoch=epoch,
+                        step=self.num_train_step + 1,
+                        batch_size=batch_size,
+                        coordinator=self.coordinator,
+                    )
+                    self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                    )
         step_bar.close()
 
     def _eval(self, epoch: int):
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index c4ef3b783d4d..62acad32f66a 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ def train(args):
     parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
     args = parser.parse_args()
     if args.config_file is not None:
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
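
Usage sketch (not part of the patch; the hyperparameter values are illustrative, and the model, optimizer, dataloader and LR scheduler are assumed to be built beforehand, as train_sft.py does). The reason plugin is now threaded through SLTrainer/SFTTrainer is that _train() checks isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1 to switch from the plain data-parallel loop to booster.execute_pipeline(). Wiring that up by hand looks roughly like this:

    from coati.trainer import SFTTrainer

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Assumes the distributed environment is already initialized (train_sft.py
    # does this via colossalai.launch_from_torch) and that `model`, `optim`,
    # `train_dataloader` and `lr_scheduler` have already been created.
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=2,           # pp_size > 1 makes SFTTrainer._train() take the pipeline branch
        microbatch_size=1,   # mirrors the new --microbatch_size CLI flag
        precision="bf16",
    )
    booster = Booster(plugin=plugin)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model, optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )

    trainer = SFTTrainer(
        model=model,
        booster=booster,
        optim=optim,
        plugin=plugin,            # new argument added by this diff
        lr_scheduler=lr_scheduler,
        max_epochs=1,
        accumulation_steps=1,     # gradient accumulation is only used on the non-pipeline path
    )

On the pipeline path the loss is only materialized on the last stage, which is why the progress bar and loss.item() are guarded by dist.get_rank() == dist.get_world_size() - 1; booster.execute_pipeline() runs the forward/backward schedule over the microbatches before the explicit optimizer.step().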