support pp training
TongLi3701 committed Aug 12, 2024
1 parent ad3fa4f commit 4a541aa
Showing 3 changed files with 84 additions and 54 deletions.
3 changes: 3 additions & 0 deletions applications/ColossalChat/coati/trainer/base.py
@@ -17,6 +17,7 @@
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,13 +39,15 @@ def __init__(
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
         self.booster = booster
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
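Note on the change above: storing the booster plugin on the base trainer lets subclasses branch on the parallelism scheme at runtime. A minimal sketch of the check the SFT trainer performs (my illustration, not code from this commit; the helper name uses_pipeline is hypothetical):

# Illustrative helper, not part of this commit: detect whether the booster
# plugin implies pipeline-parallel execution. pp_size is an attribute of
# HybridParallelPlugin; any other plugin simply fails the isinstance check.
from colossalai.booster.plugin import HybridParallelPlugin, Plugin

def uses_pipeline(plugin: Plugin) -> bool:
    return isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1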
131 changes: 78 additions & 53 deletions applications/ColossalChat/coati/trainer/sft.py
@@ -6,14 +6,16 @@
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,14 +42,15 @@ def __init__(
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
         save_interval: int = None,
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ def _before_fit(

     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
             )
-            loss = outputs.loss
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-
-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.scheduler.step()
 
-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
-
-            # Save checkpoint
-            if (
-                self.save_dir is not None
-                and self.save_interval is not None
-                and (self.num_train_step + 1) % self.save_interval == 0
-            ):
-                save_checkpoint(
-                    save_dir=self.save_dir,
-                    booster=self.booster,
-                    model=self.model,
-                    optimizer=self.optimizer,
-                    lr_scheduler=self.scheduler,
-                    epoch=epoch,
-                    step=self.num_train_step + 1,
-                    batch_size=batch_size,
-                    coordinator=self.coordinator,
-                )
-                self.coordinator.print_on_master(
-                    f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                )
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss
+
+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                    self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                # Save checkpoint
+                if (
+                    self.save_dir is not None
+                    and self.save_interval is not None
+                    and (self.num_train_step + 1) % self.save_interval == 0
+                ):
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.scheduler,
+                        epoch=epoch,
+                        step=self.num_train_step + 1,
+                        batch_size=batch_size,
+                        coordinator=self.coordinator,
+                    )
+                    self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                    )
+        step_bar.close()
 
     def _eval(self, epoch: int):
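For readability, here is the new pipeline branch of _train condensed into a standalone sketch (my illustration, mirroring the diff rather than copied from it; booster, model, optimizer, scheduler, and train_dataloader are stand-ins for the trainer's attributes, and everything is assumed to have been boosted with a HybridParallelPlugin whose pp_size > 1). booster.execute_pipeline schedules both forward and backward across the pipeline stages, and only the last stage holds a loss to report:

# Condensed sketch of the pipeline-parallel epoch added above.
import torch.distributed as dist
from tqdm import tqdm

def train_pp_epoch(booster, model, optimizer, scheduler, train_dataloader):
    data_iter = iter(train_dataloader)
    # Only the last pipeline stage materializes the loss, so only it reports.
    is_last_stage = dist.get_rank() == dist.get_world_size() - 1
    step_bar = tqdm(range(len(train_dataloader)), desc="Step", disable=not is_last_stage)
    for _ in step_bar:
        # Forward and backward are scheduled internally across the stages;
        # criterion picks the loss out of the model outputs.
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        if is_last_stage:
            step_bar.set_postfix({"train/loss": outputs["loss"].item()})
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()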
4 changes: 3 additions & 1 deletion applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
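The one-line fix above matters because HybridParallelPlugin splits each training batch into pipeline microbatches, so the microbatch size has to be configured independently of the per-step batch size. A minimal construction sketch, with illustrative values (the kwargs follow the ones visible in this hunk; the concrete numbers are assumptions, not the script's defaults):

# Illustrative values only: microbatch_size controls how each batch is
# split for the pipeline schedule and now comes from its own
# --microbatch_size flag instead of reusing args.batch_size.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,            # two pipeline stages
    max_norm=1.0,         # args.grad_clip
    precision="bf16",     # args.mixed_precision
    microbatch_size=1,    # args.microbatch_size, not args.batch_size
)
booster = Booster(plugin=plugin)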
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ def train(args):
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
     args = parser.parse_args()
     if args.config_file is not None:
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)