diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index 3b324ee784e0..90527c88b38e 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -13,7 +13,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -29,8 +29,6 @@ def train(args): # check lora compatibility if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") - if args.plugin == "gemini_auto" and args.accumulation_steps > 1: - raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") # ============================== # Initialize Distributed Training @@ -46,7 +44,7 @@ def train(args): Default torch ddp plugin without any acceleration, for debugging purpose acceleration, for debugging purpose """ - plugin = TorchDDPPlugin(find_unused_parameters=True) + plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision, @@ -56,14 +54,6 @@ def train(args): enable_gradient_accumulation=True, enable_flash_attention=args.use_flash_attn, ) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, - placement_policy="auto", - initial_scale=2**16, - max_norm=args.grad_clip, - enable_flash_attention=args.use_flash_attn, - ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin( stage=2, @@ -312,7 +302,7 @@ def train(args): "--plugin", type=str, default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"], help="Choose which plugin to use", ) parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")