remove unsupported plugin
TongLi3701 committed Sep 10, 2024
1 parent b59c487 commit a732586
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -13,7 +13,7 @@
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
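For orientation, the plugins imported here are consumed by colossalai's Booster. A minimal sketch of that wiring (standard Booster usage, not part of this commit; model, optimizer and dataloader are placeholders):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()      # initialize the distributed environment (recent colossalai versions)
plugin = TorchDDPPlugin()           # plain DDP, no extra memory or parallelism tricks
booster = Booster(plugin=plugin)
# The training script then wraps its objects once they exist:
# model, optimizer, _, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
# )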
@@ -29,8 +29,6 @@ def train(args):
     # check lora compatibility
     if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
-    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
-        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
 
     # ==============================
     # Initialize Distributed Training
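Dropping this guard means gradient accumulation on the Gemini path is no longer rejected up front; judging from the context further down in this diff, it is handled by GeminiPlugin's enable_gradient_accumulation flag. A small sketch of that configuration (values are illustrative, not taken from this diff):

from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    precision="bf16",                   # illustrative mixed-precision setting
    initial_scale=2**16,
    enable_gradient_accumulation=True,  # needed when accumulation_steps > 1 with Gemini
)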
@@ -46,7 +44,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose acceleration, for debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True)
+        plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
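The new ternary on the TorchDDPPlugin line turns off DDP's unused-parameter scan whenever gradient checkpointing is enabled, since the two features commonly conflict in PyTorch DDP. An equivalent, more direct form (assuming args.grad_checkpoint is a plain bool from an argparse store_true flag; not part of this commit):

# Same effect as: True if args.grad_checkpoint is False else False
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)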
@@ -56,14 +54,6 @@
             enable_gradient_accumulation=True,
             enable_flash_attention=args.use_flash_attn,
         )
-    elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(
-            precision=args.mixed_precision,
-            placement_policy="auto",
-            initial_scale=2**16,
-            max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn,
-        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
             stage=2,
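With gemini_auto gone, the zero2 and zero2_cpu choices remain the low-memory options. A self-contained sketch of how these two configurations typically look with LowLevelZeroPlugin (values are illustrative; the cpu_offload flag for zero2_cpu is an assumption about the usual wiring and does not appear in this diff):

from colossalai.booster.plugin import LowLevelZeroPlugin

zero2_plugin = LowLevelZeroPlugin(
    stage=2,                # ZeRO stage 2: shard gradients and optimizer states
    precision="fp16",
    initial_scale=2**16,    # initial loss scale for fp16 training
    max_norm=1.0,           # gradient clipping
)
zero2_cpu_plugin = LowLevelZeroPlugin(
    stage=2,
    precision="fp16",
    initial_scale=2**16,
    cpu_offload=True,       # additionally offload optimizer states to host memory
    max_norm=1.0,
)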
@@ -312,7 +302,7 @@ def train(args):
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+        choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")