drop prompt tuning #60

Merged
merged 1 commit on Nov 18, 2024
36 changes: 1 addition & 35 deletions dolomite_engine/arguments.py
@@ -5,7 +5,6 @@
import torch
import transformers
from packaging.version import Version
from peft import PromptTuningInit
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

from .defaults import INPUT_FORMAT, OUTPUT_FORMAT
@@ -81,27 +80,6 @@ def model_post_init(self, __context: Any) -> None:
self.model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM = getattr(transformers, self.model_class)


class PromptTuningArgs(BaseArgs):
# prompt tuning init method
prompt_tuning_init: PromptTuningInit = None
# prompt tuning init text
prompt_tuning_init_text: str | None = None
# number of virtual tokens for PEFT
num_virtual_tokens: int | None = None

def model_post_init(self, __context: Any) -> None:
_check_not_None([(self.prompt_tuning_init, "prompt_tuning_init")])

if self.prompt_tuning_init == PromptTuningInit.RANDOM:
assert (
self.prompt_tuning_init_text is None
), f"prompt_tuning_init_text '{self.prompt_tuning_init_text}' was specified with RANDOM init method"
elif self.prompt_tuning_init == PromptTuningInit.TEXT:
assert (
self.prompt_tuning_init_text is not None
), f"prompt_tuning_init_text needs to be specified with TEXT init method"


class LoRAArgs(BaseArgs):
# lora rank
lora_rank: int = None
@@ -117,8 +95,6 @@ def model_post_init(self, __context: Any) -> None:
class TuningArgs(BaseArgs):
# type of tuning, full finetuning or PEFT
tuning_method: TuningMethod = None
# prompt tuning related arguments
prompt_tuning_args: PromptTuningArgs | None = None
# lora related arguments
lora_args: LoRAArgs | None = None

@@ -127,17 +103,7 @@ def model_post_init(self, __context: Any) -> None:

# check whether the arguments specified are valid
if self.tuning_method in [TuningMethod.full_finetuning, TuningMethod.pretraining]:
assert (
self.prompt_tuning_args is None
), "prompt_tuning_args should not be specified with full_finetuning or pretraining"
assert self.lora_args is None, "lora_args should not be specified with full_finetuning or pretraining"
elif self.tuning_method == TuningMethod.prompt_tuning:
assert self.lora_args is None, "lora_args should not be specified with promt_tuning"
elif self.tuning_method == TuningMethod.lora:
assert self.prompt_tuning_args is None, "prompt_tuning_args should not be specified with lora"

def get_num_virtual_tokens(self) -> int:
return self.prompt_tuning_args.num_virtual_tokens if self.tuning_method == TuningMethod.prompt_tuning else 0
assert self.lora_args is None, "load_args should not be specified with full_finetuning or pretraining"


class TrainingParameters(BaseArgs):
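With PromptTuningArgs and get_num_virtual_tokens removed, TuningArgs reduces to a tuning_method plus an optional lora_args. A minimal sketch of constructing the pared-down arguments (a rough illustration, not taken from this repo's configs; only lora_rank is visible on LoRAArgs in this diff, and keyword construction assumes the usual pydantic-style BaseArgs behaviour):

from dolomite_engine.arguments import LoRAArgs, TuningArgs
from dolomite_engine.enums import TuningMethod

# LoRA is now the only PEFT option carried by TuningArgs
peft_tuning = TuningArgs(
    tuning_method=TuningMethod.lora,
    lora_args=LoRAArgs(lora_rank=8),
)

# full finetuning / pretraining must leave lora_args unset,
# which is what the single remaining assert in model_post_init enforces
full_finetuning = TuningArgs(tuning_method=TuningMethod.full_finetuning)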
5 changes: 0 additions & 5 deletions dolomite_engine/data/__init__.py
@@ -35,7 +35,6 @@ def get_datasets_list(
mode: Mode,
tokenizer: AutoTokenizer,
is_encoder_decoder: bool,
num_virtual_tokens: int = 0,
) -> tuple[list[BaseDataset], list[int]]:
"""get the list of datasets from their configs

@@ -45,7 +44,6 @@
mode (Mode): training / inference mode for running the program
tokenizer (AutoTokenizer): tokenizer
is_encoder_decoder (bool): whether the model is an encoder-decoder or a decoder-only model
num_virtual_tokens (int): number of tokens to use for prompt tuning

Raises:
ValueError: if invalid class_name for dataset is found
@@ -71,7 +69,6 @@
output_format=data_args.output_format,
max_input_tokens=data_args.max_input_tokens,
max_output_tokens=data_args.max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if len(dataset) > 0:
@@ -163,7 +160,6 @@ def _get_source_broadcast_mapping() -> dict:
mode=Mode.training,
tokenizer=tokenizer,
is_encoder_decoder=is_encoder_decoder,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
)

if len(datasets_list) == 0:
@@ -236,7 +232,6 @@ def _get_non_dispatching_dataloader(
mode=Mode.training,
tokenizer=tokenizer,
is_encoder_decoder=is_encoder_decoder,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
)

if len(datasets_list) == 0:
69 changes: 9 additions & 60 deletions dolomite_engine/data/base.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__()

@@ -32,9 +31,6 @@ def __init__(
self.tokenizer = tokenizer
self.is_encoder_decoder = is_encoder_decoder

# used for prompt tuning
self.num_virtual_tokens = num_virtual_tokens

self.data_name = data_name
self.input_format = input_format
self.output_format = output_format
@@ -44,12 +40,15 @@ def __init__(
self.do_format_output = self.output_format != OUTPUT_FORMAT

# length to use for trimming (excludes eos)
self.max_input_tokens = get_max_input_length(
max_input_tokens, self.num_virtual_tokens, self.is_encoder_decoder
)
self.max_output_tokens = get_max_output_length(
max_output_tokens, self.num_virtual_tokens, self.is_encoder_decoder
)
if max_input_tokens is None:
self.max_input_tokens = None
else:
self.max_input_tokens = max_input_tokens

if self.is_encoder_decoder:
self.max_input_tokens -= 1

self.max_output_tokens = None if max_output_tokens is None else max_output_tokens - 1

self.examples = []

@@ -195,53 +194,3 @@ def __repr__(self) -> str:
x += f"\nexamples in {dataset.__class__.__name__} ({dataset.data_name}) = {len(dataset)}"

return x


def get_max_input_length(
max_input_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
) -> int:
"""max input length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder

Args:
max_input_tokens_specified (int | None): maximum number of specified input tokens
num_virtual_tokens (int): virtual tokens for prompt tuning
is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder

Returns:
int: max input length
"""

if max_input_tokens_specified is None:
return None

max_input_tokens = max_input_tokens_specified - num_virtual_tokens

if is_encoder_decoder:
max_input_tokens -= 1

return max_input_tokens


def get_max_output_length(
max_output_tokens_specified: int | None, num_virtual_tokens: int, is_encoder_decoder: bool
) -> int:
"""max output length for the model, depends on the training / inference type and whether the model is decoder-only or encoder-decoder

Args:
max_output_tokens_specified (int | None): maximum number of specified output tokens
num_virtual_tokens (int): virtual tokens for prompt tuning
is_encoder_decoder (bool): whether the model is decoder-only or encoder-decoder

Returns:
int: max output length
"""

if max_output_tokens_specified is None:
return None

max_output_tokens = max_output_tokens_specified - 1

if is_encoder_decoder:
max_output_tokens -= num_virtual_tokens

return max_output_tokens
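The trimming budget in BaseDataset no longer reserves room for virtual tokens. A quick before/after arithmetic sketch with example numbers (512/256 and 8 virtual tokens are illustrative, not values from this repo):

max_input_tokens, max_output_tokens, num_virtual_tokens = 512, 256, 8

# old helpers (get_max_input_length / get_max_output_length), decoder-only model:
old_input_budget = max_input_tokens - num_virtual_tokens   # 504 (503 for encoder-decoder)
old_output_budget = max_output_tokens - 1                  # 255 (247 for encoder-decoder)

# new inline logic: nothing is set aside for prompt-tuning virtual tokens
new_input_budget = max_input_tokens                        # 512 (511 for encoder-decoder)
new_output_budget = max_output_tokens - 1                  # 255 for both model types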
2 changes: 0 additions & 2 deletions dolomite_engine/data/debug.py
@@ -19,7 +19,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int | None = None,
) -> None:
super().__init__(
class_args=class_args,
@@ -32,7 +31,6 @@ def __init__(
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if self.do_format_input:
2 changes: 0 additions & 2 deletions dolomite_engine/data/huggingface.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -33,7 +32,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

self.examples = self.prepare_examples()
2 changes: 0 additions & 2 deletions dolomite_engine/data/instruction_tuning/base.py
@@ -17,7 +17,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -30,7 +29,6 @@ def __init__(
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

if self.do_format_input:
2 changes: 0 additions & 2 deletions dolomite_engine/data/sst2.py
@@ -20,7 +20,6 @@ def __init__(
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
num_virtual_tokens: int = 0,
) -> None:
super().__init__(
class_args=class_args,
@@ -33,7 +32,6 @@
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_virtual_tokens=num_virtual_tokens,
)

self.examples = self.prepare_examples()
1 change: 0 additions & 1 deletion dolomite_engine/enums.py
@@ -58,7 +58,6 @@ class TuningMethod(str, Enum):

pretraining = "pretraining"
full_finetuning = "full_finetuning"
prompt_tuning = "prompt_tuning"
lora = "lora"
distillation = "distillation"

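Because TuningMethod is a str-valued Enum, any config that still names prompt tuning now fails at enum conversion instead of being silently accepted; a minimal sketch:

from dolomite_engine.enums import TuningMethod

TuningMethod("lora")            # still resolves to TuningMethod.lora
TuningMethod("prompt_tuning")   # now raises ValueError: the member was removed in this PR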
1 change: 0 additions & 1 deletion dolomite_engine/finetune.py
@@ -224,7 +224,6 @@ def main() -> None:
assert args.tuning_args.tuning_method in [
TuningMethod.full_finetuning,
TuningMethod.lora,
TuningMethod.prompt_tuning,
], f"unexpected tuning method ({args.tuning_args.tuning_method})"

# initialize distributed with nccl for multi-node communications
1 change: 0 additions & 1 deletion dolomite_engine/generate.py
@@ -127,7 +127,6 @@ def main() -> None:
mode=mode,
tokenizer=model.tokenizer,
is_encoder_decoder=model.is_encoder_decoder,
num_virtual_tokens=args_from_checkpoint.tuning_args.get_num_virtual_tokens(),
)

model = model.to(torch.cuda.current_device())
1 change: 0 additions & 1 deletion dolomite_engine/model_wrapper/__init__.py
@@ -13,7 +13,6 @@
TuningMethod.pretraining: ModelWrapperForPretraining,
TuningMethod.full_finetuning: ModelWrapperForFinetuning,
TuningMethod.lora: ModelWrapperForPEFT,
TuningMethod.prompt_tuning: ModelWrapperForPEFT,
TuningMethod.distillation: ModelWrapperForDistillation,
}

15 changes: 5 additions & 10 deletions dolomite_engine/model_wrapper/peft.py
@@ -1,4 +1,4 @@
from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model
from peft import LoraConfig, TaskType, get_peft_model

from ..arguments import InferenceArgs, TrainingArgs, UnshardingArgs
from ..enums import Mode, TuningMethod
@@ -20,23 +20,18 @@ def _setup_model(self, args: TrainingArgs | InferenceArgs | UnshardingArgs) -> N
model_kwargs["attn_implementation"] = self.attention_implementation.value

assert not self.use_padding_free_transformer
tuning_method = args.tuning_args.tuning_method

if args.tuning_args.tuning_method == TuningMethod.prompt_tuning:
self.peft_config = PromptTuningConfig(
task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
prompt_tuning_init=args.tuning_args.prompt_tuning_args.prompt_tuning_init,
num_virtual_tokens=args.tuning_args.get_num_virtual_tokens(),
prompt_tuning_init_text=args.tuning_args.prompt_tuning_args.prompt_tuning_init_text,
tokenizer_name_or_path=args.model_args.model_name,
)
elif args.tuning_args.tuning_method == TuningMethod.lora:
if tuning_method == TuningMethod.lora:
self.peft_config = LoraConfig(
task_type=TaskType.SEQ_2_SEQ_LM if self.is_encoder_decoder else TaskType.CAUSAL_LM,
inference_mode=self.mode != Mode.training,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
else:
raise ValueError(f"unexpected tuning_method ({tuning_method})")

self.model = args.model_args.model_class.from_pretrained(
**model_kwargs, torch_dtype=string_to_torch_dtype(self.dtype)
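After this change, _setup_model only ever builds a LoraConfig, and anything else falls into the new ValueError branch. For reference, a standalone sketch of the equivalent peft calls outside the wrapper (the model name and LoRA hyperparameters are placeholders, not values from this repo):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,   # TaskType.SEQ_2_SEQ_LM for encoder-decoder models
    inference_mode=False,           # training, mirroring mode == Mode.training in the wrapper
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # only the LoRA adapter weights remain trainable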