Skip to content

Commit

Permalink
[lora] lora support hybrid parallel plugin (hpcaitech#5956)
Browse files Browse the repository at this point in the history
* lora support hybrid plugin

* fix

* fix

* fix

* fix
  • Loading branch information
wangbluo authored Aug 2, 2024
1 parent 19d1510 commit 75c9636
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
25 changes: 22 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
Expand Down Expand Up @@ -1187,7 +1188,7 @@ def support_no_sync(self) -> bool:
return True

def support_lora(self) -> bool:
    """Report whether this plugin supports LoRA fine-tuning.

    Returns:
        bool: Always ``True``. The stale ``return False`` that preceded
        this statement made the method report LoRA as unsupported.
    """
    return True

def control_checkpoint_io(self) -> bool:
return True
Expand Down Expand Up @@ -1415,6 +1416,24 @@ def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

def enable_lora(
    self,
    model: Module,
    pretrained_dir: Optional[str] = None,
    lora_config: Optional[Dict] = None,
    bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> Module:
    """Wrap ``model`` with LoRA adapters before it is boosted.

    Args:
        model: The un-boosted model; it must not be a ``HybridParallelModule`` yet.
        pretrained_dir: Directory of previously saved adapters. When given, the
            adapters are loaded from it in trainable mode and ``lora_config``
            is not used.
        lora_config: Configuration handed to ``peft.get_peft_model`` when no
            ``pretrained_dir`` is given.
        bnb_quantization_config: When given, the model is passed through
            ``quantize_model`` before the adapters are attached.

    Returns:
        The PEFT-wrapped model.

    Raises:
        AssertionError: If the model is already boosted, or if ``pp_size`` or
            ``tp_size`` is not 1.
        ValueError: If neither ``pretrained_dir`` nor ``lora_config`` is given.
    """
    # Local import: peft is only needed once LoRA is actually requested.
    from peft import PeftModel, get_peft_model

    assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
    assert self.pp_size == 1 and self.tp_size == 1, (
        "LoRA requires pp_size == 1 and tp_size == 1, "
        f"got pp_size={self.pp_size}, tp_size={self.tp_size}"
    )
    # Validate before mutating plugin state so a bad call leaves `lora_enabled` untouched.
    if pretrained_dir is None and lora_config is None:
        raise ValueError("Either `pretrained_dir` or `lora_config` must be provided to enable LoRA.")

    self.lora_enabled = True
    warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

    if bnb_quantization_config is not None:
        model = quantize_model(model, bnb_quantization_config)

    if pretrained_dir is None:
        # Fresh adapters built from the supplied configuration.
        return get_peft_model(model, lora_config)
    # Existing adapters, loaded so that they remain trainable.
    return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
14 changes: 14 additions & 0 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,17 @@ def shard_from_complete_optimizer_state(
state_[k] = v.detach().clone().to(device)

return state_

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
    """Save the LoRA adapters of a boosted model into the ``checkpoint`` directory.

    Args:
        model: A boosted model (``ModelWrapper``) whose unwrapped module is a
            ``peft.PeftModel``.
        checkpoint: Target directory. If it is an existing regular file, an
            error is logged and nothing is saved.
        use_safetensors: Forwarded to ``save_pretrained`` as ``safe_serialization``.

    Returns:
        Whatever ``PeftModel.save_pretrained`` returns, or ``None`` when
        ``checkpoint`` points at a file.
    """
    # A regular file cannot serve as the output directory: report it and bail out.
    if os.path.isfile(checkpoint):
        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
        return None

    # Imported only after the path check, matching the point where it is first needed.
    from peft import PeftModel

    assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
    # NOTE(review): presumably blocks until pending parameter all-gathers finish — confirm.
    model._force_wait_all_gather()

    lora_model = model.unwrap()
    has_adapters = isinstance(lora_model, PeftModel)
    assert has_adapters, "The model doesn't have lora adapters, please enable lora before saving."
    return lora_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def _fullname(obj):
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
if module.startswith("peft"):
klass = obj.base_model.model.__class__
module = klass.__module__
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
Expand Down
7 changes: 5 additions & 2 deletions tests/test_lora/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
Expand All @@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
test_configs = [
{
"lora_config": lora_config,
Expand Down Expand Up @@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type

# test fwd bwd correctness
test_model = model_load
if isinstance(model_load, HybridParallelModule):
model_load = model_load.module.module
model_copy = copy.deepcopy(model_load)

data = data_gen_fn()
Expand Down

0 comments on commit 75c9636

Please sign in to comment.