diff --git a/apps/accelerate/chatllama/artifacts/config/config.yaml b/apps/accelerate/chatllama/artifacts/config/config.yaml index bbffcd04..1cdba191 100644 --- a/apps/accelerate/chatllama/artifacts/config/config.yaml +++ b/apps/accelerate/chatllama/artifacts/config/config.yaml @@ -51,7 +51,7 @@ actor_config: additonal_prompt_tokens: 20 # temperature for the actor temperature: 0.1 - batch_size: 1 + batch_size: 2 # number iteration after print iteration_per_print: 1 lr: 0.000009 @@ -69,6 +69,9 @@ actor_config: deepspeed_config_path: "./artifacts/config/ds_config.json" # accelerate settings accelerate_enable: False + # use_peft + peft_enable: True + peft_config_path: "./artifacts/config/peft_config.yaml" reward_config: # model to be chosen are gp2-large, bart-base, longformer-base-4096 diff --git a/apps/accelerate/chatllama/artifacts/config/peft_config.yaml b/apps/accelerate/chatllama/artifacts/config/peft_config.yaml new file mode 100644 index 00000000..c19599b4 --- /dev/null +++ b/apps/accelerate/chatllama/artifacts/config/peft_config.yaml @@ -0,0 +1,5 @@ +--- +inference_mode: False +r: 8 +lora_alpha: 32 +lora_dropout: 0.1 diff --git a/apps/accelerate/chatllama/chatllama/rlhf/actor.py b/apps/accelerate/chatllama/chatllama/rlhf/actor.py index 36a72de3..ae33b4b1 100644 --- a/apps/accelerate/chatllama/chatllama/rlhf/actor.py +++ b/apps/accelerate/chatllama/chatllama/rlhf/actor.py @@ -1,6 +1,7 @@ import json -import shutil +import yaml import os +import shutil import deepspeed import torch @@ -8,6 +9,7 @@ from beartype import beartype from beartype.typing import Tuple from einops import rearrange +from peft import get_peft_model, LoraConfig, TaskType from torch.utils.data import DataLoader, Dataset from transformers import ( AutoModelForCausalLM, @@ -72,7 +74,34 @@ def __init__(self, config: ConfigActor) -> None: self.model = AutoModelForCausalLM.from_pretrained( config.model, ) + + # Setup PEFT model + if config.peft_enable: + + # check that the peft config exists + if
os.path.exists(config.peft_config_path): + # Read the peft config from yaml + with open(config.peft_config_path, "r") as c: + config_peft = yaml.safe_load(c) + else: + raise ValueError( + f"PEFT config {config.peft_config_path} not found" + ) + + print(config_peft) + # define lora config for peft + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, **config_peft + ) + + # create peft model + self.model = get_peft_model( + model=self.model, + peft_config=peft_config, + ) + self.model.to(config.device) + else: raise ValueError(f"Model {config.model} not supported") diff --git a/apps/accelerate/chatllama/chatllama/rlhf/config.py b/apps/accelerate/chatllama/chatllama/rlhf/config.py index d5904ed9..baaa5380 100644 --- a/apps/accelerate/chatllama/chatllama/rlhf/config.py +++ b/apps/accelerate/chatllama/chatllama/rlhf/config.py @@ -125,7 +125,10 @@ class ConfigActor: device (torch.device): Device to be used for the actor checkpoint_name (Optional[str]): Name of the checkpoint. Default to None. + peft_enable (bool): Enable peft for the actor + peft_config_path (str): Path to the peft config file. debug (bool): Enable prints for debugging + """ model: str @@ -153,6 +156,8 @@ class ConfigActor: accelerate_enable: bool device: torch.device + peft_enable: bool + peft_config_path: str checkpoint_name: Optional[str] = None debug: bool = False