
Commit

Merge pull request #173 from gkumbhat/change_param_name
🚚 Change parameter name and mark them optional for runtime
gkumbhat authored Sep 7, 2023
2 parents 1855669 + f7b68a7 · commit c16c33a
Showing 1 changed file with 21 additions and 19 deletions.
caikit_nlp/modules/text_generation/peft_prompt_tuning.py (21 additions, 19 deletions)
@@ -298,21 +298,23 @@ def train(
             DataStream[ClassificationTrainRecord],
         ],
         tuning_config: TuningConfig,
-        val_stream: Union[
-            DataStream[GenerationTrainRecord],
-            DataStream[ClassificationTrainRecord],
+        val_stream: Optional[
+            Union[
+                DataStream[GenerationTrainRecord],
+                DataStream[ClassificationTrainRecord],
+            ]
         ] = None,  # TODO: Optional[DataStream[GenerationTrainRecord]]
-        device: str = _DETECT_DEVICE,  # TODO: Union[int, str]
-        tuning_type: str = "PROMPT_TUNING",  # TODO: Union[str, TuningType]
-        num_epochs: int = 20,
-        lr: float = 0.3,
-        verbalizer: str = "{{input}}",
-        batch_size: int = 8,
-        max_source_length: int = 256,
-        max_target_length: int = 128,
-        accumulate_steps: int = 32,
-        torch_dtype: str = None,  # TODO: Optional[Union[torch.dtype, str]]
-        silence_progress_bars: bool = True,
+        device: Optional[str] = _DETECT_DEVICE,  # TODO: Union[int, str]
+        tuning_type: Optional[str] = "PROMPT_TUNING",  # TODO: Union[str, TuningType]
+        num_epochs: Optional[int] = 20,
+        learning_rate: Optional[float] = 0.3,
+        verbalizer: Optional[str] = "{{input}}",
+        batch_size: Optional[int] = 8,
+        max_source_length: Optional[int] = 256,
+        max_target_length: Optional[int] = 128,
+        accumulate_steps: Optional[int] = 32,
+        torch_dtype: Optional[str] = None,  # TODO: Optional[Union[torch.dtype, str]]
+        silence_progress_bars: Optional[bool] = True,
         **kwargs,
     ) -> "PeftPromptTuning":
         """Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model
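The pattern applied in this hunk, taken in isolation: each tuning argument is wrapped in Optional with a concrete default so a runtime caller can omit it, and the rate is exposed under the descriptive keyword learning_rate rather than lr. A minimal standalone sketch of that pattern follows (toy function and record type only, not the caikit_nlp API; names and values here are illustrative assumptions):

from typing import Optional
from dataclasses import dataclass


@dataclass
class ToyTrainRecord:
    """Illustrative stand-in for a GenerationTrainRecord-style record."""
    input: str
    output: str


def train_toy(
    train_records: list,
    val_records: Optional[list] = None,      # may be omitted entirely
    learning_rate: Optional[float] = 0.3,    # formerly exposed as `lr`
    num_epochs: Optional[int] = 20,
) -> None:
    # Optional parameters fall back to their defaults when the caller
    # (for example an automated runtime) leaves them out.
    rate = 0.3 if learning_rate is None else learning_rate
    epochs = 20 if num_epochs is None else num_epochs
    print(f"{len(train_records)} records, {epochs} epochs, learning_rate={rate}")


# Callers now use the renamed keyword:
train_toy([ToyTrainRecord("hello", "world")], learning_rate=0.1)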
@@ -336,7 +338,7 @@ def train(
                 Type of Peft Tuning config which we would like to build.
             num_epochs: int
                 Number of epochs to tune the prompt vectors. Default: 20.
-            lr: float
+            learning_rate: float
                 Learning rate to be used while tuning prompt vectors. Default: 1e-3.
             verbalizer: str
                 Verbalizer template to be used for formatting data at train and inference time.
@@ -543,7 +545,7 @@ def train(
             device,
             eval_dataloader=val_dataloader,
             metric=metric,
-            lr=lr,
+            learning_rate=learning_rate,
             tokenizer=base_model.tokenizer,
             accumulate_steps=accumulate_steps,
             silence_progress_bars=silence_progress_bars,
@@ -1082,7 +1084,7 @@ def _execute_train_loop(
         device: str,
         eval_dataloader: Union[DataLoader, None] = None,
         metric: Optional[Callable] = None,
-        lr: int = 1e-3,
+        learning_rate: int = 1e-3,
         tokenizer: Union[AutoTokenizer, None] = None,
         accumulate_steps: int = 1,
         silence_progress_bars: bool = True,
@@ -1104,7 +1106,7 @@ def _execute_train_loop(
             metric: Union[Callable, None]
                 Function to be used for evaluating data if an eval data loader is provided.
                 Default: None.
-            lr: float
+            learning_rate: float
                 Learning rate to be used while tuning prompt vectors. Default: 1e-3.
             tokenizer: Union[AutoTokenizer, None]
                 Tokenizer for default evaluation; only used if no metric is provided and we have
@@ -1115,7 +1117,7 @@ def _execute_train_loop(
             silence_progress_bars: bool
                 Silences TQDM progress bars. Default: True
         """
-        optimizer = AdamW(params=model.parameters(), lr=lr)
+        optimizer = AdamW(params=model.parameters(), lr=learning_rate)
         lr_scheduler = get_linear_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=0,
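The changed line above feeds the renamed argument straight into the optimizer. A minimal runnable sketch of that optimizer/scheduler setup (the toy model, import paths, and step count are assumptions for illustration, not taken from this diff):

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(16, 16)  # stand-in for the tuned PEFT model
learning_rate = 1e-3             # value formerly passed as `lr`

optimizer = AdamW(params=model.parameters(), lr=learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=100,  # illustrative; the real loop derives this from the dataloader
)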
