diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index aa89adb7..a40c3aca 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -187,7 +187,7 @@ def run( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> GeneratedTextResult: """ Run the full text generation model. @@ -247,7 +247,7 @@ def run_stream_out( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run the text generation model with output streaming diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index ef3eb6fc..6b712104 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -184,7 +184,7 @@ def run( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. @@ -240,7 +240,7 @@ def run_stream_out( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing against the model running in TGIS diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 61a5331b..e2130ed7 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -14,7 +14,7 @@ # Standard -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import os # Third Party @@ -220,7 +220,7 @@ def run( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. @@ -271,7 +271,7 @@ def run_stream_out( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing for text generation module. diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index b8454249..45699166 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -15,7 +15,7 @@ """Utility functions used for executing run function for text_generation""" # Standard -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union # Third Party from transformers import StoppingCriteria, TextStreamer @@ -97,7 +97,7 @@ consist of: (start_index, decay_factor) where start_index indicates where penalty starts and decay_factor represents the factor of exponential decay - stop_sequences: List[str]: + stop_sequences: List[str] List of strings to be used as stopping criteria """ @@ -148,7 +148,7 @@ def generate_text_func( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, **kwargs, ): """ @@ -256,7 +256,7 @@ def generate_text_func_stream( exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] ] = None, - stop_sequences: Optional[str] = None, + stop_sequences: Optional[List[str]] = None, **kwargs, ): """