drop prompt tuning (#60)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Nov 18, 2024
1 parent e59b47f commit d108aa0
Showing 12 changed files with 15 additions and 122 deletions.
36 changes: 1 addition & 35 deletions dolomite_engine/arguments.py
@@ -5,7 +5,6 @@
import torch
import transformers
from packaging.version import Version
from peft import PromptTuningInit
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

from .defaults import INPUT_FORMAT, OUTPUT_FORMAT
@@ -81,27 +80,6 @@ def model_post_init(self, __context: Any) -> None:
self.model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM = getattr(transformers, self.model_class)


class PromptTuningArgs(BaseArgs):
# prompt tuning init method
prompt_tuning_init: PromptTuningInit = None
# prompt tuning init text
prompt_tuning_init_text: str | None = None
# number of virtual tokens for PEFT
num_virtual_tokens: int | None = None

def model_post_init(self, __context: Any) -> None:
_check_not_None([(self.prompt_tuning_init, "prompt_tuning_init")])

if self.prompt_tuning_init == PromptTuningInit.RANDOM:
assert (
self.prompt_tuning_init_text is None
), f"prompt_tuning_init_text '{self.prompt_tuning_init_text}' was specified with RANDOM init method"
elif self.prompt_tuning_init == PromptTuningInit.TEXT:
assert (
self.prompt_tuning_init_text is not None
), f"prompt_tuning_init_text needs to be specified with TEXT init method"


class LoRAArgs(BaseArgs):
# lora rank
lora_rank: int = None
@@ -117,8 +95,6 @@ def model_post_init(self, __context: Any) -> None:
class TuningArgs(BaseArgs):
# type of tuning, full finetuning or PEFT
tuning_method: TuningMethod = None
# prompt tuning related arguments
prompt_tuning_args: PromptTuningArgs | None = None
# lora related arguments
lora_args: LoRAArgs | None = None

@@ -127,17 +103,7 @@ def model_post_init(self, __context: Any) -> None:

# check whether the arguments specified are valid
if self.tuning_method in [TuningMethod.full_finetuning, TuningMethod.pretraining]:
assert (
self.prompt_tuning_args is None
), "prompt_tuning_args should not be specified with full_finetuning or pretraining"
assert self.lora_args is None, "lora_args should not be specified with full_finetuning or pretraining"
elif self.tuning_method == TuningMethod.prompt_tuning:
assert self.lora_args is None, "lora_args should not be specified with promt_tuning"
elif self.tuning_method == TuningMethod.lora:
assert self.prompt_tuning_args is None, "prompt_tuning_args should not be specified with lora"

def get_num_virtual_tokens(self) -> int:
return self.prompt_tuning_args.num_virtual_tokens if self.tuning_method == TuningMethod.prompt_tuning else 0
assert self.lora_args is None, "load_args should not be specified with full_finetuning or pretraining"


class TrainingParameters(BaseArgs):
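For context, a minimal self-contained sketch of the validation that remains in TuningArgs after this commit; the enum and class below are simplified stand-ins for the repository's BaseArgs-based definitions, so treat it as an illustration of the surviving logic rather than the resulting file:

from enum import Enum


class TuningMethod(str, Enum):
    pretraining = "pretraining"
    full_finetuning = "full_finetuning"
    lora = "lora"
    distillation = "distillation"


class TuningArgs:
    def __init__(self, tuning_method: TuningMethod, lora_args: dict | None = None) -> None:
        self.tuning_method = tuning_method
        self.lora_args = lora_args

        # the only cross-check left once prompt tuning is gone: LoRA arguments must not
        # accompany full finetuning or pretraining
        if self.tuning_method in [TuningMethod.full_finetuning, TuningMethod.pretraining]:
            assert self.lora_args is None, "lora_args should not be specified with full_finetuning or pretraining"


TuningArgs(tuning_method=TuningMethod.full_finetuning)                   # passes
TuningArgs(tuning_method=TuningMethod.lora, lora_args={"lora_rank": 8})  # passes

The get_num_virtual_tokens() helper disappears along with PromptTuningArgs, which is why the data and generation call sites later in this diff drop their num_virtual_tokens keyword arguments.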
5 changes: 0 additions & 5 deletions dolomite_engine/data/__init__.py
@@ -35,7 +35,6 @@ def get_datasets_list(
mode: Mode,
tokenizer: AutoTokenizer,
is_encoder_decoder: bool,
num_virtual_tokens: int = 0,
) -> tuple[list[BaseDataset], list[int]]:
"""get the list of datasets from their configs
@@ -45,7 +44,6 @@
mode (Mode): training / inference mode for running the program
tokenizer (AutoTokenizer): tokenizer
is_encoder_decoder (bool): whether the model is an encoder-decoder or a decoder-only model
num_virtual_tokens (int): number of tokens to use for prompt tuning
Raises:
ValueError: if invalid class_name for dataset is found
@@ -71,7 +69,6 @@
output_format=data_args.output_format,
max_input_tokens=data_args.max_input_tokens,
max_output_tokens=data_args.max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if len(dataset) > 0:
@@ -163,7 +160,6 @@ def _get_source_broadcast_mapping() -> dict:
mode=Mode.training,
tokenizer=tokenizer,
is_encoder_decoder=is_encoder_decoder,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
)

if len(datasets_list) == 0:
@@ -236,7 +232,6 @@ def _get_non_dispatching_dataloader(
mode=Mode.training,
tokenizer=tokenizer,
is_encoder_decoder=is_encoder_decoder,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
)

if len(datasets_list) == 0:
69 changes: 9 additions & 60 deletions dolomite_engine/data/base.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__()

@@ -32,9 +31,6 @@
self.tokenizer = tokenizer
self.is_encoder_decoder = is_encoder_decoder

# used for prompt tuning
self.num_virtual_tokens = num_virtual_tokens

self.data_name = data_name
self.input_format = input_format
self.output_format = output_format
@@ -44,12 +40,15 @@ def __init__(
self.do_format_output = self.output_format != OUTPUT_FORMAT

# length to use for trimming (excludes eos)
self.max_input_tokens = get_max_input_length(
max_input_tokens, self.num_virtual_tokens, self.is_encoder_decoder
)
self.max_output_tokens = get_max_output_length(
max_output_tokens, self.num_virtual_tokens, self.is_encoder_decoder
)
if max_input_tokens is None:
self.max_input_tokens = None
else:
self.max_input_tokens = max_input_tokens

if self.is_encoder_decoder:
self.max_input_tokens -= 1

self.max_output_tokens = None if max_output_tokens is None else max_output_tokens - 1

self.examples = []

@@ -195,53 +194,3 @@ def __repr__(self) -> str:
x += f"\nexamples in {dataset.__class__.__name__} ({dataset.data_name}) = {len(dataset)}"

return x


def get_max_input_length(
max_input_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
) -> int:
"""max input length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder
Args:
max_input_tokens_specified (int | None): maximum number of specified input tokens
num_virtual_tokens (int): virtual tokens for prompt tuning
is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder
Returns:
int: max input length
"""

if max_input_tokens_specified is None:
return None

max_input_tokens = max_input_tokens_specified - num_virtual_tokens

if is_encoder_decoder:
max_input_tokens -= 1

return max_input_tokens


def get_max_output_length(
max_output_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
) -> int:
"""max output length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder
Args:
max_output_tokens_specified (int | None): maximum number of specified output tokens
num_virtual_tokens (int): virtual tokens for prompt tuning
is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder
Returns:
int: max output length
"""

if max_output_tokens_specified is None:
return None

max_output_tokens = max_output_tokens_specified - 1

if is_encoder_decoder:
max_output_tokens -= num_virtual_tokens

return max_output_tokens
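The removed get_max_input_length and get_max_output_length helpers are replaced by the inline logic in BaseDataset.__init__ shown above; restated here as a self-contained helper for clarity (the function name is illustrative, not from the repository):

def trim_lengths(
    max_input_tokens: int | None, max_output_tokens: int | None, is_encoder_decoder: bool
) -> tuple[int | None, int | None]:
    # no virtual-token budget is subtracted any more; only the EOS-related adjustments
    # remain: one token off the input for encoder-decoder models, and one token off the
    # output whenever an output limit is given
    if max_input_tokens is not None and is_encoder_decoder:
        max_input_tokens -= 1
    if max_output_tokens is not None:
        max_output_tokens -= 1
    return max_input_tokens, max_output_tokens


assert trim_lengths(512, 128, is_encoder_decoder=False) == (512, 127)
assert trim_lengths(512, 128, is_encoder_decoder=True) == (511, 127)
assert trim_lengths(None, None, is_encoder_decoder=True) == (None, None)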
2 changes: 0 additions & 2 deletions dolomite_engine/data/debug.py
@@ -19,7 +19,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int | None = None,
) -> None:
super().__init__(
class_args=class_args,
@@ -32,7 +31,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if self.do_format_input:
2 changes: 0 additions & 2 deletions dolomite_engine/data/huggingface.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -33,7 +32,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

self.examples = self.prepare_examples()
2 changes: 0 additions & 2 deletions dolomite_engine/data/instruction_tuning/base.py
@@ -17,7 +17,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -30,7 +29,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if self.do_format_input:
2 changes: 0 additions & 2 deletions dolomite_engine/data/sst2.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -33,7 +32,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

self.examples = self.prepare_examples()
1 change: 0 additions & 1 deletion dolomite_engine/enums.py
@@ -58,7 +58,6 @@ class TuningMethod(str, Enum):

pretraining = "pretraining"
full_finetuning = "full_finetuning"
prompt_tuning = "prompt_tuning"
lora = "lora"
distillation = "distillation"

1 change: 0 additions & 1 deletion dolomite_engine/finetune.py
@@ -224,7 +224,6 @@ def main() -> None:
assert args.tuning_args.tuning_method in [
TuningMethod.full_finetuning,
TuningMethod.lora,
TuningMethod.prompt_tuning,
], f"unexpected tuning method ({args.tuning_args.tuning_method})"

# initialize distributed with nccl for multi-node communications
1 change: 0 additions & 1 deletion dolomite_engine/generate.py
@@ -127,7 +127,6 @@ def main() -> None:
mode=mode,
tokenizer=model.tokenizer,
is_encoder_decoder=model.is_encoder_decoder,
num_virtual_tokens=args_from_checkpoint.tuning_args.get_num_virtual_tokens(),
)

model = model.to(torch.cuda.current_device())
1 change: 0 additions & 1 deletion dolomite_engine/model_wrapper/__init__.py
@@ -13,7 +13,6 @@
TuningMethod.pretraining: ModelWrapperForPretraining,
TuningMethod.full_finetuning: ModelWrapperForFinetuning,
TuningMethod.lora: ModelWrapperForPEFT,
TuningMethod.prompt_tuning: ModelWrapperForPEFT,
TuningMethod.distillation: ModelWrapperForDistillation,
}

15 changes: 5 additions & 10 deletions dolomite_engine/model_wrapper/peft.py
@@ -1,4 +1,4 @@
from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model
from peft import LoraConfig, TaskType, get_peft_model

from ..arguments import InferenceArgs, TrainingArgs, UnshardingArgs
from ..enums import Mode, TuningMethod
@@ -20,23 +20,18 @@ def _setup_model(self, args: TrainingArgs | InferenceArgs | UnshardingArgs) -> N
model_kwargs["attn_implementation"] = self.attention_implementation.value

assert not self.use_padding_free_transformer
tuning_method = args.tuning_args.tuning_method

if args.tuning_args.tuning_method == TuningMethod.prompt_tuning:
self.peft_config = PromptTuningConfig(
task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
prompt_tuning_init=args.tuning_args.prompt_tuning_args.prompt_tuning_init,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
prompt_tuning_init_text=args.tuning_args.prompt_tuning_args.prompt_tuning_init_text,
tokenizer_name_or_path=args.model_args.model_name,
)
elif args.tuning_args.tuning_method == TuningMethod.lora:
if tuning_method == TuningMethod.lora:
self.peft_config = LoraConfig(
task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
inference_mode=self.mode != Mode.training,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
else:
raise ValueError(f"unexpected tuning_method ({tuning_method})")

self.model = args.model_args.model_class.from_pretrained(
**model_kwargs, torch_dtype=string_to_torch_dtype(self.dtype)
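For readers unfamiliar with the peft API, a hedged sketch of how the surviving LoRA path is typically wired up; the wrapper above builds self.peft_config and loads the base model through args.model_args.model_class, so the model name and hyperparameter values below are placeholders rather than the repository's defaults:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # TaskType.SEQ_2_SEQ_LM for encoder-decoder models
    inference_mode=False,          # True outside of training, as in the diff above
    r=8,                           # lora_rank
    lora_alpha=16,
    lora_dropout=0.05,
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # only the LoRA adapter weights remain trainable

With prompt tuning removed, any other tuning_method reaching this wrapper now raises a ValueError, as the new else branch in the diff shows.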
