From 75c963686f889fa68110cac460baaad08ac78f82 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:36:58 +0800 Subject: [PATCH] [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix --- .../booster/plugin/hybrid_parallel_plugin.py | 25 ++++++++++++++++--- .../hybrid_parallel_checkpoint_io.py | 14 +++++++++++ .../shardformer/policies/auto_policy.py | 3 +++ tests/test_lora/test_lora.py | 7 ++++-- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fcb7478140fb..d2933a4afe7f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -30,6 +30,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy @@ -1187,7 +1188,7 @@ def support_no_sync(self) -> bool: return True def support_lora(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1415,6 +1416,24 @@ def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() def enable_lora( - self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> Module: - raise NotImplementedError + from peft import PeftModel, get_peft_model + + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b7097e432a1d..0310df5489b0 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -947,3 +947,17 @@ def shard_from_complete_optimizer_state( state_[k] = v.detach().clone().to(device) return state_ + + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index f2533da4ba4d..7b9c759a66c2 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -243,6 +243,9 @@ def _fullname(obj): # patch custom models which are not in transformers # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub) # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory) + if module.startswith("peft"): + klass = obj.base_model.model.__class__ + module = klass.__module__ if module.startswith("transformers_modules"): split_module = module.split(".") if len(split_module) >= 2: diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index b8daf775db0e..1ae17025d31e 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -9,7 +9,8 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_checkpoint_io.utils import shared_tempdir @@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)] test_configs = [ { "lora_config": lora_config, @@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type # test fwd bwd correctness test_model = model_load + if isinstance(model_load, HybridParallelModule): + model_load = model_load.module.module model_copy = copy.deepcopy(model_load) data = data_gen_fn()