From 169efdeeefde88d7806012af9573286832fe2522 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Mon, 18 Nov 2024 14:45:20 -0500
Subject: [PATCH] drop prompt tuning

Signed-off-by: Mayank Mishra
---
 dolomite_engine/arguments.py               | 36 +---------
 dolomite_engine/data/__init__.py           |  5 --
 dolomite_engine/data/base.py               | 69 +++----------------
 dolomite_engine/data/debug.py              |  2 -
 dolomite_engine/data/huggingface.py        |  2 -
 .../data/instruction_tuning/base.py        |  2 -
 dolomite_engine/data/sst2.py               |  2 -
 dolomite_engine/enums.py                   |  1 -
 dolomite_engine/finetune.py                |  1 -
 dolomite_engine/generate.py                |  1 -
 dolomite_engine/model_wrapper/__init__.py  |  1 -
 dolomite_engine/model_wrapper/peft.py      | 15 ++--
 12 files changed, 15 insertions(+), 122 deletions(-)

diff --git a/dolomite_engine/arguments.py b/dolomite_engine/arguments.py
index 1cac85b9..81cd73e4 100644
--- a/dolomite_engine/arguments.py
+++ b/dolomite_engine/arguments.py
@@ -5,7 +5,6 @@
 import torch
 import transformers
 from packaging.version import Version
-from peft import PromptTuningInit
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
 from .defaults import INPUT_FORMAT, OUTPUT_FORMAT
@@ -81,27 +80,6 @@ def model_post_init(self, __context: Any) -> None:
         self.model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM = getattr(transformers, self.model_class)
 
 
-class PromptTuningArgs(BaseArgs):
-    # prompt tuning init method
-    prompt_tuning_init: PromptTuningInit = None
-    # prompt tuning init text
-    prompt_tuning_init_text: str | None = None
-    # number of virtual tokens for PEFT
-    num_virtual_tokens: int | None = None
-
-    def model_post_init(self, __context: Any) -> None:
-        _check_not_None([(self.prompt_tuning_init, "prompt_tuning_init")])
-
-        if self.prompt_tuning_init == PromptTuningInit.RANDOM:
-            assert (
-                self.prompt_tuning_init_text is None
-            ), f"prompt_tuning_init_text '{self.prompt_tuning_init_text}' was specified with RANDOM init method"
-        elif self.prompt_tuning_init == PromptTuningInit.TEXT:
-            assert (
-                self.prompt_tuning_init_text is not None
-            ), f"prompt_tuning_init_text needs to be specified with TEXT init method"
-
-
 class LoRAArgs(BaseArgs):
     # lora rank
     lora_rank: int = None
@@ -117,8 +95,6 @@ def model_post_init(self, __context: Any) -> None:
 class TuningArgs(BaseArgs):
     # type of tuning, full finetuning or PEFT
     tuning_method: TuningMethod = None
-    # prompt tuning related arguments
-    prompt_tuning_args: PromptTuningArgs | None = None
     # lora related arguments
     lora_args: LoRAArgs | None = None
 
@@ -127,17 +103,7 @@ def model_post_init(self, __context: Any) -> None:
 
         # check whether the arguments specified are valid
         if self.tuning_method in [TuningMethod.full_finetuning, TuningMethod.pretraining]:
-            assert (
-                self.prompt_tuning_args is None
-            ), "prompt_tuning_args should not be specified with full_finetuning or pretraining"
-            assert self.lora_args is None, "lora_args should not be specified with full_finetuning or pretraining"
-        elif self.tuning_method == TuningMethod.prompt_tuning:
-            assert self.lora_args is None, "lora_args should not be specified with promt_tuning"
-        elif self.tuning_method == TuningMethod.lora:
-            assert self.prompt_tuning_args is None, "prompt_tuning_args should not be specified with lora"
-
-    def get_num_virtual_tokens(self) -> int:
-        return self.prompt_tuning_args.num_virtual_tokens if self.tuning_method == TuningMethod.prompt_tuning else 0
+            assert self.lora_args is None, "lora_args should not be specified with full_finetuning or pretraining"
 
 
 class TrainingParameters(BaseArgs):
diff --git a/dolomite_engine/data/__init__.py b/dolomite_engine/data/__init__.py
index 1a316038..abc810ef 100644
--- a/dolomite_engine/data/__init__.py
+++ b/dolomite_engine/data/__init__.py
@@ -35,7 +35,6 @@ def get_datasets_list(
     mode: Mode,
     tokenizer: AutoTokenizer,
     is_encoder_decoder: bool,
-    num_virtual_tokens: int = 0,
 ) -> tuple[list[BaseDataset], list[int]]:
     """get the list of datasets from their configs
 
@@ -45,7 +44,6 @@ def get_datasets_list(
         mode (Mode): training / inference mode for running the program
         tokenizer (AutoTokenizer): tokenizer
         is_encoder_decoder (bool): whether the model is an encoder-decoder or a decoder-only model
-        num_virtual_tokens (int): number of tokens to use for prompt tuning
 
     Raises:
         ValueError: if invalid class_name for dataset is found
@@ -71,7 +69,6 @@ def get_datasets_list(
             output_format=data_args.output_format,
             max_input_tokens=data_args.max_input_tokens,
             max_output_tokens=data_args.max_output_tokens,
-            num_virtual_tokens=num_virtual_tokens,
         )
 
         if len(dataset) > 0:
@@ -163,7 +160,6 @@ def _get_source_broadcast_mapping() -> dict:
         mode=Mode.training,
         tokenizer=tokenizer,
         is_encoder_decoder=is_encoder_decoder,
-        num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
     )
 
     if len(datasets_list) == 0:
@@ -236,7 +232,6 @@ def _get_non_dispatching_dataloader(
         mode=Mode.training,
         tokenizer=tokenizer,
         is_encoder_decoder=is_encoder_decoder,
-        num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
     )
 
     if len(datasets_list) == 0:
diff --git a/dolomite_engine/data/base.py b/dolomite_engine/data/base.py
index fdf160af..31707547 100644
--- a/dolomite_engine/data/base.py
+++ b/dolomite_engine/data/base.py
@@ -20,7 +20,6 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-        num_virtual_tokens: int = 0,
     ) -> None:
         super().__init__()
 
@@ -32,9 +31,6 @@ def __init__(
         self.tokenizer = tokenizer
         self.is_encoder_decoder = is_encoder_decoder
 
-        # used for prompt tuning
-        self.num_virtual_tokens = num_virtual_tokens
-
         self.data_name = data_name
         self.input_format = input_format
         self.output_format = output_format
@@ -44,12 +40,15 @@ def __init__(
         self.do_format_output = self.output_format != OUTPUT_FORMAT
 
         # length to use for trimming (excludes eos)
-        self.max_input_tokens = get_max_input_length(
-            max_input_tokens, self.num_virtual_tokens, self.is_encoder_decoder
-        )
-        self.max_output_tokens = get_max_output_length(
-            max_output_tokens, self.num_virtual_tokens, self.is_encoder_decoder
-        )
+        if max_input_tokens is None:
+            self.max_input_tokens = None
+        else:
+            self.max_input_tokens = max_input_tokens
+
+            if self.is_encoder_decoder:
+                self.max_input_tokens -= 1
+
+        self.max_output_tokens = None if max_output_tokens is None else max_output_tokens - 1
 
         self.examples = []
 
@@ -195,53 +194,3 @@ def __repr__(self) -> str:
             x += f"\nexamples in {dataset.__class__.__name__} ({dataset.data_name}) = {len(dataset)}"
 
         return x
-
-
-def get_max_input_length(
-    max_input_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
-) -> int:
-    """max input length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder
-
-    Args:
-        max_input_tokens_specified (int | None): maximum number of specified input tokens
-        num_virtual_tokens (int): virtual tokens for prompt tuning
-        is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder
-
-    Returns:
-        int: max input length
-    """
-
-    if max_input_tokens_specified is None:
-        return None
-
-    max_input_tokens = max_input_tokens_specified - num_virtual_tokens
-
-    if is_encoder_decoder:
-        max_input_tokens -= 1
-
-    return max_input_tokens
-
-
-def get_max_output_length(
-    max_output_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
-) -> int:
-    """max output length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder
-
-    Args:
-        max_output_tokens_specified (int | None): maximum number of specified output tokens
-        num_virtual_tokens (int): virtual tokens for prompt tuning
-        is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder
-
-    Returns:
-        int: max output length
-    """
-
-    if max_output_tokens_specified is None:
-        return None
-
-    max_output_tokens = max_output_tokens_specified - 1
-
-    if is_encoder_decoder:
-        max_output_tokens -= num_virtual_tokens
-
-    return max_output_tokens
diff --git a/dolomite_engine/data/debug.py b/dolomite_engine/data/debug.py
index 03d03061..a0325b11 100644
--- a/dolomite_engine/data/debug.py
+++ b/dolomite_engine/data/debug.py
@@ -19,7 +19,6 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-        num_virtual_tokens: int | None = None,
     ) -> None:
         super().__init__(
             class_args=class_args,
@@ -32,7 +31,6 @@ def __init__(
             output_format=output_format,
             max_input_tokens=max_input_tokens,
             max_output_tokens=max_output_tokens,
-            num_virtual_tokens=num_virtual_tokens,
         )
 
         if self.do_format_input:
diff --git a/dolomite_engine/data/huggingface.py b/dolomite_engine/data/huggingface.py
index 97d1aa90..1007de3d 100644
--- a/dolomite_engine/data/huggingface.py
+++ b/dolomite_engine/data/huggingface.py
@@ -20,7 +20,6 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-        num_virtual_tokens: int = 0,
     ) -> None:
         super().__init__(
             class_args=class_args,
@@ -33,7 +32,6 @@ def __init__(
             output_format=output_format,
             max_input_tokens=max_input_tokens,
             max_output_tokens=max_output_tokens,
-            num_virtual_tokens=num_virtual_tokens,
         )
 
         self.examples = self.prepare_examples()
diff --git a/dolomite_engine/data/instruction_tuning/base.py b/dolomite_engine/data/instruction_tuning/base.py
index ed9cf243..c6319827 100644
--- a/dolomite_engine/data/instruction_tuning/base.py
+++ b/dolomite_engine/data/instruction_tuning/base.py
@@ -17,7 +17,6 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-        num_virtual_tokens: int = 0,
     ) -> None:
         super().__init__(
             class_args=class_args,
@@ -30,7 +29,6 @@ def __init__(
             output_format=output_format,
             max_input_tokens=max_input_tokens,
             max_output_tokens=max_output_tokens,
-            num_virtual_tokens=num_virtual_tokens,
         )
 
         if self.do_format_input:
diff --git a/dolomite_engine/data/sst2.py b/dolomite_engine/data/sst2.py
index 91b4cb8d..474fa48a 100644
--- a/dolomite_engine/data/sst2.py
+++ b/dolomite_engine/data/sst2.py
@@ -20,7 +20,6 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-        num_virtual_tokens: int = 0,
     ) -> None:
         super().__init__(
             class_args=class_args,
@@ -33,7 +32,6 @@ def __init__(
            output_format=output_format,
             max_input_tokens=max_input_tokens,
             max_output_tokens=max_output_tokens,
-            num_virtual_tokens=num_virtual_tokens,
         )
 
         self.examples = self.prepare_examples()
diff --git a/dolomite_engine/enums.py b/dolomite_engine/enums.py
index 41ba915a..8fd46de9 100644
--- a/dolomite_engine/enums.py
+++ b/dolomite_engine/enums.py
@@ -58,7 +58,6 @@ class TuningMethod(str, Enum):
 
     pretraining = "pretraining"
     full_finetuning = "full_finetuning"
-    prompt_tuning = "prompt_tuning"
     lora = "lora"
     distillation = "distillation"
 
diff --git a/dolomite_engine/finetune.py b/dolomite_engine/finetune.py
index b9c433d5..1dd8caa4 100644
--- a/dolomite_engine/finetune.py
+++ b/dolomite_engine/finetune.py
@@ -224,7 +224,6 @@ def main() -> None:
     assert args.tuning_args.tuning_method in [
         TuningMethod.full_finetuning,
         TuningMethod.lora,
-        TuningMethod.prompt_tuning,
     ], f"unexpected tuning method ({args.tuning_args.tuning_method})"
 
     # initialize distributed with nccl for multi-node communications
diff --git a/dolomite_engine/generate.py b/dolomite_engine/generate.py
index 9f2274ae..d2062756 100644
--- a/dolomite_engine/generate.py
+++ b/dolomite_engine/generate.py
@@ -127,7 +127,6 @@ def main() -> None:
         mode=mode,
         tokenizer=model.tokenizer,
         is_encoder_decoder=model.is_encoder_decoder,
-        num_virtual_tokens=args_from_checkpoint.tuning_args.get_num_virtual_tokens(),
     )
 
     model = model.to(torch.cuda.current_device())
diff --git a/dolomite_engine/model_wrapper/__init__.py b/dolomite_engine/model_wrapper/__init__.py
index f0404f21..a4d24009 100644
--- a/dolomite_engine/model_wrapper/__init__.py
+++ b/dolomite_engine/model_wrapper/__init__.py
@@ -13,7 +13,6 @@
     TuningMethod.pretraining: ModelWrapperForPretraining,
     TuningMethod.full_finetuning: ModelWrapperForFinetuning,
     TuningMethod.lora: ModelWrapperForPEFT,
-    TuningMethod.prompt_tuning: ModelWrapperForPEFT,
     TuningMethod.distillation: ModelWrapperForDistillation,
 }
 
diff --git a/dolomite_engine/model_wrapper/peft.py b/dolomite_engine/model_wrapper/peft.py
index f8cc6a07..502e9661 100644
--- a/dolomite_engine/model_wrapper/peft.py
+++ b/dolomite_engine/model_wrapper/peft.py
@@ -1,4 +1,4 @@
-from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model
+from peft import LoraConfig, TaskType, get_peft_model
 
 from ..arguments import InferenceArgs, TrainingArgs, UnshardingArgs
 from ..enums import Mode, TuningMethod
@@ -20,16 +20,9 @@ def _setup_model(self, args: TrainingArgs | InferenceArgs | UnshardingArgs) -> N
             model_kwargs["attn_implementation"] = self.attention_implementation.value
 
         assert not self.use_padding_free_transformer
+        tuning_method = args.tuning_args.tuning_method
 
-        if args.tuning_args.tuning_method == TuningMethod.prompt_tuning:
-            self.peft_config = PromptTuningConfig(
-                task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
-                prompt_tuning_init=args.tuning_args.prompt_tuning_args.prompt_tuning_init,
-                num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
-                prompt_tuning_init_text=args.tuning_args.prompt_tuning_args.prompt_tuning_init_text,
-                tokenizer_name_or_path=args.model_args.model_name,
-            )
-        elif args.tuning_args.tuning_method == TuningMethod.lora:
+        if tuning_method == TuningMethod.lora:
             self.peft_config = LoraConfig(
                 task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
                 inference_mode=self.mode != Mode.training,
@@ -37,6 +30,8 @@ def _setup_model(self, args: TrainingArgs | InferenceArgs | UnshardingArgs) -> N
                 lora_alpha=args.lora_alpha,
                 lora_dropout=args.lora_dropout,
             )
+        else:
+            raise ValueError(f"unexpected tuning_method ({tuning_method})")
 
         self.model = args.model_args.model_class.from_pretrained(
             **model_kwargs, torch_dtype=string_to_torch_dtype(self.dtype)