From a2949a705d346da258a1b6a3e2d4ece684ffffae Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Fri, 25 Aug 2023 21:17:36 -0700
Subject: [PATCH 01/12] =?UTF-8?q?=E2=9C=A8=20Add=20additional=20text=20gen?=
 =?UTF-8?q?=20inference=20parameters?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 caikit_nlp/toolkit/tgis_utils.py | 65 ++++++++++++++++++++++++++++++--
 1 file changed, 61 insertions(+), 4 deletions(-)

diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py
index a1d44d4c..a2a32be7 100644
--- a/caikit_nlp/toolkit/tgis_utils.py
+++ b/caikit_nlp/toolkit/tgis_utils.py
@@ -33,10 +33,17 @@ def get_params(
     preserve_input_text,
-    eos_token,
     max_new_tokens,
     min_new_tokens,
     truncate_input_tokens,
+    decoding_method,
+    temperature,
+    top_k,
+    top_p,
+    typical_p,
+    # seed,
+    repetition_penalty,
+    stop_sequences,
 ):
     """Get generation parameters

@@ -53,6 +60,22 @@ def get_params(
         truncate_input_tokens: int
             Truncate inputs to provided number of tokens.
     """
+
+    if decoding_method == "GREEDY":
+        decoding = generation_pb2.DecodingMethod.GREEDY
+    elif decoding_method == "SAMPLING":
+        decoding = generation_pb2.DecodingMethod.SAMPLE
+
+    # decoding = generation_pb2.DecodingMethod.__getattr__(decoding_method)
+
+    sampling_parameters = generation_pb2.SamplingParameters(
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        typical_p=typical_p,
+        # seed=seed
+    )
+
     res_options = generation_pb2.ResponseOptions(
         input_text=preserve_input_text,
         generated_tokens=True,
@@ -61,13 +84,21 @@ def get_params(
         token_ranks=True,
     )
     stopping = generation_pb2.StoppingCriteria(
-        stop_sequences=[eos_token] if eos_token is not None else None,
+        stop_sequences=stop_sequences,
         max_new_tokens=max_new_tokens,
         min_new_tokens=min_new_tokens,
     )
+
+    decoding_parameters = generation_pb2.DecodingParameters(
+        repetition_penalty=repetition_penalty
+    )
+
     params = generation_pb2.Parameters(
+        method=decoding,
+        sampling=sampling_parameters,
         response=res_options,
         stopping=stopping,
+        decoding=decoding_parameters,
         truncate_input_tokens=truncate_input_tokens,
     )
     return params
@@ -92,6 +123,13 @@ def unary_generate(
         max_new_tokens,
         min_new_tokens,
         truncate_input_tokens,
+        decoding_method,
+        temperature,
+        top_k,
+        top_p,
+        typical_p,
+        repetition_penalty,
+        stop_sequences,
     ) -> GeneratedTextResult:
         """Generate unary output from model in TGIS

@@ -129,10 +167,16 @@ def unary_generate(
         params = get_params(
             preserve_input_text=preserve_input_text,
-            eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            repetition_penalty=repetition_penalty,
+            stop_sequences=stop_sequences,
         )

         gen_reqs = [generation_pb2.GenerationRequest(text=text)]
@@ -175,6 +219,13 @@ def stream_generate(
         max_new_tokens,
         min_new_tokens,
         truncate_input_tokens,
+        decoding_method,
+        temperature,
+        top_k,
+        top_p,
+        typical_p,
+        repetition_penalty,
+        stop_sequences,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Generate stream output from model in TGIS

@@ -209,10 +260,16 @@ def stream_generate(
         params = get_params(
             preserve_input_text=preserve_input_text,
-            eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            repetition_penalty=repetition_penalty,
+            stop_sequences=stop_sequences,
         )

         gen_req = generation_pb2.GenerationRequest(text=text)

From a599d13d4980daf81b1a024e317086fc163459ab Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Fri, 25 Aug 2023 21:37:32 -0700
Subject: [PATCH 02/12] =?UTF-8?q?=E2=9C=A8=20Add=20new=20parameters=20in?=
 =?UTF-8?q?=20text=20gen=20modules?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 .../text_generation/peft_tgis_remote.py      | 121 ++++++++++++++++--
 .../text_generation/text_generation_tgis.py  | 113 +++++++++++++++-
 caikit_nlp/toolkit/tgis_utils.py             |   2 +
 .../test_text_generation_tgis.py             |   2 +-
 4 files changed, 221 insertions(+), 17 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index 65112bea..c7c5d170 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -31,7 +31,7 @@
 import alog

 # Local
-from ...toolkit.tgis_utils import TGISGenerationClient
+from ...toolkit.tgis_utils import VALID_DECODING_METHODS, TGISGenerationClient
 from ...toolkit.verbalizer_utils import render_verbalizer
 from . import PeftPromptTuning

@@ -158,14 +158,22 @@ def save(self, model_path: str):
             }
         )

+    # pylint: disable=duplicate-code
     @TextGenerationTask.taskmethod()
     def run(
         self,
-        text,
-        preserve_input_text=False,
-        max_new_tokens=20,
-        min_new_tokens=0,
-        truncate_input_tokens=0,
+        text: str,
+        preserve_input_text: bool = False,
+        max_new_tokens: int = 20,
+        min_new_tokens: int = 0,
+        truncate_input_tokens: int = 0,
+        decoding_method: str = "GREEDY",
+        temperature: float = 0.0,
+        top_k: int = 0,
+        top_p: float = 0.0,
+        typical_p: float = 0.0,
+        repetition_penalty: float = 0.0,
+        stop_sequences: List[str] = None,
     ) -> GeneratedTextResult:
         """Run inference against the model running in TGIS. Currently we leverage greedy
         decoding and apply the same verbalizer used for training the local model prior to sending the
@@ -188,6 +196,36 @@ def run(
                 use to avoid failing due to input being longer than
                 configured limits.
                 Default: 0 - means don't truncate, thus throw error.
+            decoding_method: str
+                Parameters for conditionally penalizing / boosting
+                candidate tokens during decoding.
+                Options: "GREEDY" (default), "SAMPLING"
+            temperature: float
+                The value used to modulate the next token probabilities.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            top_k: int
+                The number of highest probability vocabulary tokens to keep for
+                top-k-filtering. Only applicable when decoding_method is SAMPLING.
+                Default: 0 - means disabled
+            top_p: float
+                If set to float < 1, only the smallest set of most probable tokens
+                with probabilities that add up to top_p or higher are kept for
+                generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            typical_p: float
+                Local typicality measures how similar the conditional probability of
+                predicting a target token next is to the expected conditional
+                probability of predicting a random token next, given the partial text
+                already generated. If set to float < 1, the smallest set of the most
+                locally typical tokens with probabilities that add up to typical_p
+                or higher are kept for generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            repetition_penalty: float
+                The more a token is used within generation the more it is penalized
+                to not be picked in successive generation passes.
+                Default: 0.0 - means no penalty - equivalent to 1.0
+            stop_sequences: List(str)
+                Sequences to be considered for stopping generation.
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
@@ -197,6 +235,12 @@ def run(
             self.enable_backend,
             "Backend must be configured and loaded with this module before executing `run` call.",
         )
+        error.value_check(
+            "",
+            decoding_method in VALID_DECODING_METHODS,
+            f"Decoding method [{decoding_method}] not in valid decoding methods: "
+            f"[{VALID_DECODING_METHODS}]",
+        )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.unary_generate(
             verbalized_text,
@@ -204,16 +248,30 @@ def run(
             max_new_tokens,
             min_new_tokens,
             truncate_input_tokens,
+            decoding_method,
+            temperature,
+            top_k,
+            top_p,
+            typical_p,
+            repetition_penalty,
+            stop_sequences,
         )

     @TextGenerationTask.taskmethod(output_streaming=True)
     def run_stream_out(
         self,
         text: str,
-        preserve_input_text=False,
-        max_new_tokens=20,
-        min_new_tokens=0,
-        truncate_input_tokens=0,
+        preserve_input_text: bool = False,
+        max_new_tokens: int = 20,
+        min_new_tokens: int = 0,
+        truncate_input_tokens: int = 0,
+        decoding_method: str = "GREEDY",
+        temperature: float = 0.0,
+        top_k: int = 0,
+        top_p: float = 0.0,
+        typical_p: float = 0.0,
+        repetition_penalty: float = 0.0,
+        stop_sequences: List[str] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Run output stream inferencing against the model running in TGIS

@@ -234,6 +292,36 @@ def run_stream_out(
                 use to avoid failing due to input being longer than
                 configured limits.
                 Default: 0 - means don't truncate, thus throw error.
+            decoding_method: str
+                Parameters for conditionally penalizing / boosting
+                candidate tokens during decoding.
+                Options: "GREEDY" (default), "SAMPLING"
+            temperature: float
+                The value used to modulate the next token probabilities.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            top_k: int
+                The number of highest probability vocabulary tokens to keep for
+                top-k-filtering. Only applicable when decoding_method is SAMPLING.
+                Default: 0 - means disabled
+            top_p: float
+                If set to float < 1, only the smallest set of most probable tokens
+                with probabilities that add up to top_p or higher are kept for
+                generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            typical_p: float
+                Local typicality measures how similar the conditional probability of
+                predicting a target token next is to the expected conditional
+                probability of predicting a random token next, given the partial text
+                already generated. If set to float < 1, the smallest set of the most
+                locally typical tokens with probabilities that add up to typical_p
+                or higher are kept for generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            repetition_penalty: float
+                The more a token is used within generation the more it is penalized
+                to not be picked in successive generation passes.
+                Default: 0.0 - means no penalty - equivalent to 1.0
+            stop_sequences: List(str)
+                Sequences to be considered for stopping generation.
         Returns:
             Iterable[GeneratedTextStreamResult]
         """
         error.value_check(
             "",
             self.enable_backend,
             "Backend must be configured and loaded with this module \
                before executing `run_stream_out` call.",
         )
+        error.value_check(
+            "",
+            decoding_method in VALID_DECODING_METHODS,
+            f"Decoding method [{decoding_method}] not in valid decoding methods: "
+            f"[{VALID_DECODING_METHODS}]",
+        )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.stream_generate(
             verbalized_text,
@@ -250,4 +344,11 @@ def run_stream_out(
             max_new_tokens,
             min_new_tokens,
             truncate_input_tokens,
+            decoding_method,
+            temperature,
+            top_k,
+            top_p,
+            typical_p,
+            repetition_penalty,
+            stop_sequences,
         )
diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py
index 0504a97d..c5c0c292 100644
--- a/caikit_nlp/modules/text_generation/text_generation_tgis.py
+++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,7 +14,7 @@

 # Standard
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Union
 import os

 # First Party
@@ -35,7 +35,7 @@
     HFAutoSeq2SeqLM,
     PretrainedModelBase,
 )
-from ...toolkit.tgis_utils import TGISGenerationClient
+from ...toolkit.tgis_utils import VALID_DECODING_METHODS, TGISGenerationClient
 from .text_generation_local import TextGeneration

 log = alog.use_channel("TXT_GEN")
@@ -200,6 +200,13 @@ def run(
         max_new_tokens: int = 20,
         min_new_tokens: int = 0,
         truncate_input_tokens: int = 0,
+        decoding_method: str = "GREEDY",
+        temperature: float = 0.0,
+        top_k: int = 0,
+        top_p: float = 0.0,
+        typical_p: float = 0.0,
+        repetition_penalty: float = 0.0,
+        stop_sequences: List[str] = None,
     ) -> GeneratedTextResult:
         """Run inference against the model running in TGIS.

@@ -220,10 +227,47 @@ def run(
                 use to avoid failing due to input being longer than
                 configured limits.
                 Default: 0 - means don't truncate, thus throw error.
+            decoding_method: str
+                Parameters for conditionally penalizing / boosting
+                candidate tokens during decoding.
+                Options: "GREEDY" (default), "SAMPLING"
+            temperature: float
+                The value used to modulate the next token probabilities.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            top_k: int
+                The number of highest probability vocabulary tokens to keep for
+                top-k-filtering. Only applicable when decoding_method is SAMPLING.
+                Default: 0 - means disabled
+            top_p: float
+                If set to float < 1, only the smallest set of most probable tokens
+                with probabilities that add up to top_p or higher are kept for
+                generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            typical_p: float
+                Local typicality measures how similar the conditional probability of
+                predicting a target token next is to the expected conditional
+                probability of predicting a random token next, given the partial text
+                already generated. If set to float < 1, the smallest set of the most
+                locally typical tokens with probabilities that add up to typical_p
+                or higher are kept for generation.
+                Default: 0.0 - means disabled - equivalent to 1.0
+            repetition_penalty: float
+                The more a token is used within generation the more it is penalized
+                to not be picked in successive generation passes.
+                Default: 0.0 - means no penalty - equivalent to 1.0
+            stop_sequences: List(str)
+                Sequences to be considered for stopping generation.
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
""" + error.value_check( + "", + decoding_method in VALID_DECODING_METHODS, + f"Decoding method [{decoding_method}] not in valid decoding methods: " + f"[{VALID_DECODING_METHODS}]", + ) + if self._model_loaded: return self.tgis_generation_client.unary_generate( text, @@ -231,16 +275,30 @@ def run( max_new_tokens, min_new_tokens, truncate_input_tokens, + decoding_method, + temperature, + top_k, + top_p, + typical_p, + repetition_penalty, + stop_sequences, ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( self, text: str, - preserve_input_text=False, - max_new_tokens=20, - min_new_tokens=0, - truncate_input_tokens=0, + preserve_input_text: bool = False, + max_new_tokens: int = 20, + min_new_tokens: int = 0, + truncate_input_tokens: int = 0, + decoding_method: str = "GREEDY", + temperature: float = 0.0, + top_k: int = 0, + top_p: float = 0.0, + typical_p: float = 0.0, + repetition_penalty: float = 0.0, + stop_sequences: List[str] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing for text generation module. @@ -259,9 +317,45 @@ def run_stream_out( use to avoid failing due to input being longer than configured limits. Default: 0 - means don't truncate, thus throw error. + decoding_method: str + Parameters for conditionally penalizing / boosting + candidate tokens during decoding. + Options: "GREEDY" (default), "SAMPLING" + temperature: float + The value used to modulate the next token probabilities. + Default: 0.0 - means disabled - equivalent to 1.0 + top_k: int + The number of highest probability vocabulary tokens to keep for + top-k-filtering. Only applicable when decoding_method is SAMPLING. + Default: 0 - means disabled + top_p: float + If set to float < 1, only the smallest set of most probable tokens + with probabilities that add up to top_p or higher are kept for + generation. + Default: 0.0 - means disabled - equivalent to 1.0 + typical_p: float + Local typicality measures how similar the conditional probability of + predicting a target token next is to the expected conditional + probability of predicting a random token next, given the partial text + already generated. If set to float < 1, the smallest set of the most + locally typical tokens with probabilities that add up to typical_p + or higher are kept for generation. + Default: 0.0 - means disabled - equivalent to 1.0 + repetition_penalty: float + The more a token is used within generation the more it is penalized + to not be picked in successive generation passes. + Default: 0.0 - means no penalty - equivalent to 1.0 + stop_sequences: List(str) + Sequences to be considered for stopping generation. 
         Returns:
             Iterable[GeneratedTextStreamResult]
         """
+        error.value_check(
+            "",
+            decoding_method in VALID_DECODING_METHODS,
+            f"Decoding method [{decoding_method}] not in valid decoding methods: "
+            f"[{VALID_DECODING_METHODS}]",
+        )
         if self._model_loaded:
             return self.tgis_generation_client.stream_generate(
                 text,
@@ -269,4 +363,11 @@ def run_stream_out(
                 max_new_tokens,
                 min_new_tokens,
                 truncate_input_tokens,
+                decoding_method,
+                temperature,
+                top_k,
+                top_p,
+                typical_p,
+                repetition_penalty,
+                stop_sequences,
             )
diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py
index a2a32be7..a520fc16 100644
--- a/caikit_nlp/toolkit/tgis_utils.py
+++ b/caikit_nlp/toolkit/tgis_utils.py
@@ -30,6 +30,8 @@
 log = alog.use_channel("TGIS_UTILS")
 error = error_handler.get(log)

+VALID_DECODING_METHODS = ["GREEDY", "SAMPLING"]
+
 def get_params(
     preserve_input_text,
diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py
index 0be8afcf..6434305b 100644
--- a/tests/modules/text_generation/test_text_generation_tgis.py
+++ b/tests/modules/text_generation/test_text_generation_tgis.py
@@ -35,7 +35,7 @@ def test_bootstrap_and_run_causallm():
         CAUSAL_LM_MODEL, load_backend=StubTGISBackend()
     )

-    result = model.run(SAMPLE_TEXT, preserve_input_text=True)
+    result = model.run(SAMPLE_TEXT, preserve_input_text=True, repetition_penalty=50.0)
     StubTGISClient.validate_unary_generate_response(result)

From def12c2d9a88fde1d63260e7157f3ba5c0722a81 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Tue, 29 Aug 2023 14:27:18 -0700
Subject: [PATCH 03/12] =?UTF-8?q?=E2=9C=A8=20Add=20more=20params?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 .../text_generation/peft_tgis_remote.py      |  98 +++++++---------
 .../text_generation/text_generation_tgis.py  |  99 ++++++++--------
 caikit_nlp/toolkit/tgis_utils.py             |  27 +++--
 3 files changed, 127 insertions(+), 97 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index c7c5d170..030194f7 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -15,7 +15,7 @@
 prompt vectors in TGIS generation requests.
""" # Standard -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple import os # First Party @@ -164,16 +164,18 @@ def run( self, text: str, preserve_input_text: bool = False, - max_new_tokens: int = 20, - min_new_tokens: int = 0, - truncate_input_tokens: int = 0, - decoding_method: str = "GREEDY", - temperature: float = 0.0, - top_k: int = 0, - top_p: float = 0.0, - typical_p: float = 0.0, - repetition_penalty: float = 0.0, - stop_sequences: List[str] = None, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, + stop_sequences: Optional[str] = None, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. Currently we leverage greedy decoding and apply the same verbalizer used for training the local model prior to sending the @@ -243,18 +245,20 @@ def run( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.unary_generate( - verbalized_text, - preserve_input_text, - max_new_tokens, - min_new_tokens, - truncate_input_tokens, - decoding_method, - temperature, - top_k, - top_p, - typical_p, - repetition_penalty, - stop_sequences, + text=verbalized_text, + preserve_input_text=preserve_input_text, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_time=max_time, + exponential_decay_length_penalty=exponential_decay_length_penalty, + stop_sequences=stop_sequences, ) @TextGenerationTask.taskmethod(output_streaming=True) @@ -262,16 +266,18 @@ def run_stream_out( self, text: str, preserve_input_text: bool = False, - max_new_tokens: int = 20, - min_new_tokens: int = 0, - truncate_input_tokens: int = 0, - decoding_method: str = "GREEDY", - temperature: float = 0.0, - top_k: int = 0, - top_p: float = 0.0, - typical_p: float = 0.0, - repetition_penalty: float = 0.0, - stop_sequences: List[str] = None, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, + stop_sequences: Optional[str] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing against the model running in TGIS @@ -339,16 +345,18 @@ def run_stream_out( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.stream_generate( - verbalized_text, - preserve_input_text, - max_new_tokens, - min_new_tokens, - truncate_input_tokens, - decoding_method, - temperature, - top_k, - top_p, - typical_p, - repetition_penalty, - stop_sequences, + text=verbalized_text, + preserve_input_text=preserve_input_text, + max_new_tokens=max_new_tokens, + 
min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_time=max_time, + exponential_decay_length_penalty=exponential_decay_length_penalty, + stop_sequences=stop_sequences, ) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index c5c0c292..2feb2778 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -14,7 +14,7 @@ # Standard -from typing import Iterable, List, Optional, Union +from typing import Iterable, Optional, Tuple, Union import os # First Party @@ -197,16 +197,18 @@ def run( self, text: str, preserve_input_text: bool = False, - max_new_tokens: int = 20, - min_new_tokens: int = 0, - truncate_input_tokens: int = 0, - decoding_method: str = "GREEDY", - temperature: float = 0.0, - top_k: int = 0, - top_p: float = 0.0, - typical_p: float = 0.0, - repetition_penalty: float = 0.0, - stop_sequences: List[str] = None, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, + stop_sequences: Optional[str] = None, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. @@ -270,35 +272,40 @@ def run( if self._model_loaded: return self.tgis_generation_client.unary_generate( - text, - preserve_input_text, - max_new_tokens, - min_new_tokens, - truncate_input_tokens, - decoding_method, - temperature, - top_k, - top_p, - typical_p, - repetition_penalty, - stop_sequences, + text=text, + preserve_input_text=preserve_input_text, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_time=max_time, + exponential_decay_length_penalty=exponential_decay_length_penalty, + stop_sequences=stop_sequences, ) + # pylint: disable=duplicate-code @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( self, text: str, preserve_input_text: bool = False, - max_new_tokens: int = 20, - min_new_tokens: int = 0, - truncate_input_tokens: int = 0, - decoding_method: str = "GREEDY", - temperature: float = 0.0, - top_k: int = 0, - top_p: float = 0.0, - typical_p: float = 0.0, - repetition_penalty: float = 0.0, - stop_sequences: List[str] = None, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, + stop_sequences: Optional[str] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing for text generation module. 
@@ -358,16 +365,18 @@ def run_stream_out(
         if self._model_loaded:
             return self.tgis_generation_client.stream_generate(
-                text,
-                preserve_input_text,
-                max_new_tokens,
-                min_new_tokens,
-                truncate_input_tokens,
-                decoding_method,
-                temperature,
-                top_k,
-                top_p,
-                typical_p,
-                repetition_penalty,
-                stop_sequences,
+                text=text,
+                preserve_input_text=preserve_input_text,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                truncate_input_tokens=truncate_input_tokens,
+                decoding_method=decoding_method,
+                top_k=top_k,
+                top_p=top_p,
+                typical_p=typical_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                max_time=max_time,
+                exponential_decay_length_penalty=exponential_decay_length_penalty,
+                stop_sequences=stop_sequences,
             )
diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py
index a520fc16..cc99baba 100644
--- a/caikit_nlp/toolkit/tgis_utils.py
+++ b/caikit_nlp/toolkit/tgis_utils.py
@@ -32,6 +32,8 @@

 VALID_DECODING_METHODS = ["GREEDY", "SAMPLING"]

+# pylint: disable=duplicate-code
+
 def get_params(
     preserve_input_text,
@@ -39,12 +41,13 @@ def get_params(
     min_new_tokens,
     truncate_input_tokens,
     decoding_method,
-    temperature,
     top_k,
     top_p,
     typical_p,
-    # seed,
+    temperature,
     repetition_penalty,
+    max_time,
+    exponential_decay_length_penalty,
     stop_sequences,
 ):
     """Get generation parameters
@@ -89,10 +92,12 @@ def get_params(
         stop_sequences=stop_sequences,
         max_new_tokens=max_new_tokens,
         min_new_tokens=min_new_tokens,
+        time_limit_millis=max_time,
     )

     decoding_parameters = generation_pb2.DecodingParameters(
-        repetition_penalty=repetition_penalty
+        repetition_penalty=repetition_penalty,
+        length_penalty=exponential_decay_length_penalty,
     )

     params = generation_pb2.Parameters(
@@ -126,11 +131,13 @@ def unary_generate(
         min_new_tokens,
         truncate_input_tokens,
         decoding_method,
-        temperature,
         top_k,
         top_p,
         typical_p,
+        temperature,
         repetition_penalty,
+        max_time,
+        exponential_decay_length_penalty,
         stop_sequences,
     ) -> GeneratedTextResult:
         """Generate unary output from model in TGIS
@@ -173,11 +180,13 @@ def unary_generate(
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
-            temperature=temperature,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
+            temperature=temperature,
             repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )
@@ -222,11 +231,13 @@ def stream_generate(
         min_new_tokens,
         truncate_input_tokens,
         decoding_method,
-        temperature,
         top_k,
         top_p,
         typical_p,
+        temperature,
         repetition_penalty,
+        max_time,
+        exponential_decay_length_penalty,
         stop_sequences,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Generate stream output from model in TGIS
@@ -266,11 +277,13 @@ def stream_generate(
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
-            temperature=temperature,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
+            temperature=temperature,
             repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )

From 713da85ad66f0a6736bcba76e35284a6da5596b4 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Tue, 29 Aug 2023 14:27:27 -0700
Subject: [PATCH 04/12] =?UTF-8?q?=E2=9C=85=20Add=20tests=20with=20optional?=
 =?UTF-8?q?=20dependencies?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 .../text_generation/test_peft_tgis_remote.py |  6 +-
 .../test_text_generation_tgis.py             | 55 ++++++++++++++++++-
 2 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/tests/modules/text_generation/test_peft_tgis_remote.py b/tests/modules/text_generation/test_peft_tgis_remote.py
index 50df588c..8fd4da3b 100644
--- a/tests/modules/text_generation/test_peft_tgis_remote.py
+++ b/tests/modules/text_generation/test_peft_tgis_remote.py
@@ -38,7 +38,9 @@ def test_load_and_run(causal_lm_dummy_model, stub_tgis_backend):
     model_prompt_dir = os.path.split(model_dir)[-1]

     # Run an inference request, which is wrapped around our mocked Generate call
-    result = mock_tgis_model.run(SAMPLE_TEXT, preserve_input_text=True)
+    result = mock_tgis_model.run(
+        SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50
+    )
     StubTGISClient.validate_unary_generate_response(result)

     stub_generation_request = mock_gen.call_args_list[0].args[0]
@@ -72,7 +74,7 @@ def test_load_and_run_stream_out(causal_lm_dummy_model, stub_tgis_backend):

     # Run an inference request, which is wrapped around our mocked GenerateStream call
     stream_result = mock_tgis_model.run_stream_out(
-        SAMPLE_TEXT, preserve_input_text=True
+        SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50
     )
     StubTGISClient.validate_stream_generate_response(stream_result)

diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py
index 6434305b..16686f02 100644
--- a/tests/modules/text_generation/test_text_generation_tgis.py
+++ b/tests/modules/text_generation/test_text_generation_tgis.py
@@ -35,7 +35,7 @@ def test_bootstrap_and_run_causallm():
         CAUSAL_LM_MODEL, load_backend=StubTGISBackend()
     )

-    result = model.run(SAMPLE_TEXT, preserve_input_text=True, repetition_penalty=50.0)
+    result = model.run(SAMPLE_TEXT, preserve_input_text=True)
     StubTGISClient.validate_unary_generate_response(result)

@@ -157,3 +157,56 @@ def test_run_stream_out_with_runtime_error():
         response = model.run_stream_out(SAMPLE_TEXT, preserve_input_text=True)
         # Need to iterate over stream for error
         next(response)
+
+
+######################## Test run with optional params #####################
+
+
+def test_bootstrap_and_run_causallm_with_optional_params():
+    """Check if we can bootstrap and run causallm models with optional dependencies"""
+
+    model = TextGenerationTGIS.bootstrap(
+        CAUSAL_LM_MODEL, load_backend=StubTGISBackend()
+    )
+
+    result = model.run(
+        SAMPLE_TEXT,
+        preserve_input_text=True,
+        max_new_tokens=200,
+        min_new_tokens=50,
+        truncate_input_tokens=10,
+        decoding_method="GREEDY",
+        top_k=0,
+        top_p=0.1,
+        typical_p=0.5,
+        temperature=0.75,
+        repetition_penalty=0.3,
+        max_time=1000,
+        exponential_decay_length_penalty=(),
+        stop_sequences=["This is a test"],
+    )
+    StubTGISClient.validate_unary_generate_response(result)
+
+
+def test_bootstrap_and_run_stream_out_with_optional_dependencies():
+    """Check if we can bootstrap and run_stream_out with optional dependencies"""
+    model = TextGenerationTGIS.bootstrap(
+        SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend()
+    )
+
+    stream_result = model.run_stream_out(
+        SAMPLE_TEXT,
+        max_new_tokens=200,
+        min_new_tokens=50,
+        truncate_input_tokens=10,
+        decoding_method="GREEDY",
+        top_k=0,
+        top_p=0.1,
+        typical_p=0.5,
+        temperature=0.75,
+        repetition_penalty=0.3,
+        max_time=1000,
+        exponential_decay_length_penalty=(),
+        stop_sequences=["This is a test"],
+    )
+    StubTGISClient.validate_stream_generate_response(stream_result)

From 987394299457faa13a58d6ad10518152b16ff3a2 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Tue, 29 Aug 2023 22:11:43 -0700
Subject: [PATCH 05/12] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Add=20more=20inferen?=
 =?UTF-8?q?ce=20params?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 .../text_generation/peft_tgis_remote.py      | 126 ++++-----------
 .../text_generation/text_generation_tgis.py  | 125 ++++-----------
 caikit_nlp/toolkit/tgis_utils.py             | 144 +++++++++++++-----
 .../test_text_generation_tgis.py             |   8 +-
 4 files changed, 170 insertions(+), 233 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index 030194f7..88710991 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -15,7 +15,7 @@
 prompt vectors in TGIS generation requests.
 """
 # Standard
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 import os

 # First Party
@@ -31,7 +31,12 @@
 import alog

 # Local
-from ...toolkit.tgis_utils import VALID_DECODING_METHODS, TGISGenerationClient
+from ...data_model import ExponentialDecayLengthPenalty
+from ...toolkit.tgis_utils import (
+    GENERATE_FUNCTION_ARGS,
+    VALID_DECODING_METHODS,
+    TGISGenerationClient,
+)
 from ...toolkit.verbalizer_utils import render_verbalizer
 from . import PeftPromptTuning

@@ -158,7 +163,6 @@ def save(self, model_path: str):
             }
         )

-    # pylint: disable=duplicate-code
     @TextGenerationTask.taskmethod()
     def run(
         self,
         text: str,
         preserve_input_text: bool = False,
         max_new_tokens: Optional[int] = 20,
         min_new_tokens: Optional[int] = 0,
         truncate_input_tokens: Optional[int] = 0,
         decoding_method: Optional[str] = "GREEDY",
         top_k: Optional[int] = 0,
         top_p: Optional[float] = 0.0,
         typical_p: Optional[float] = 0.0,
         temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
         repetition_penalty: Optional[float] = 0.0,
         max_time: Optional[float] = None,
-        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
         stop_sequences: Optional[str] = None,
     ) -> GeneratedTextResult:
         """Run inference against the model running in TGIS. Currently we leverage greedy
         decoding and apply the same verbalizer used for training the local model prior to sending the
         request to TGIS.

         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: str
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum numbers of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum numbers of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to provided number of tokens. This can be
-                use to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, thus throw error.
-            decoding_method: str
-                Parameters for conditionally penalizing / boosting
-                candidate tokens during decoding.
-                Options: "GREEDY" (default), "SAMPLING"
-            temperature: float
-                The value used to modulate the next token probabilities.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            top_k: int
-                The number of highest probability vocabulary tokens to keep for
-                top-k-filtering. Only applicable when decoding_method is SAMPLING.
-                Default: 0 - means disabled
-            top_p: float
-                If set to float < 1, only the smallest set of most probable tokens
-                with probabilities that add up to top_p or higher are kept for
-                generation.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            typical_p: float
-                Local typicality measures how similar the conditional probability of
-                predicting a target token next is to the expected conditional
-                probability of predicting a random token next, given the partial text
-                already generated. If set to float < 1, the smallest set of the most
-                locally typical tokens with probabilities that add up to typical_p
-                or higher are kept for generation.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            repetition_penalty: float
-                The more a token is used within generation the more it is penalized
-                to not be picked in successive generation passes.
-                Default: 0.0 - means no penalty - equivalent to 1.0
-            stop_sequences: List(str)
-                Sequences to be considered for stopping generation.
+            {}
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
-        """
+        """.format(
+            GENERATE_FUNCTION_ARGS
+        )
+
         error.value_check(
             "",
             self.enable_backend,
             "Backend must be configured and loaded with this module before executing `run` call.",
         )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.unary_generate(
             text=verbalized_text,
             preserve_input_text=preserve_input_text,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
             temperature=temperature,
+            seed=seed,
             repetition_penalty=repetition_penalty,
             max_time=max_time,
             exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )

     @TextGenerationTask.taskmethod(output_streaming=True)
     def run_stream_out(
         self,
         text: str,
         preserve_input_text: bool = False,
         max_new_tokens: Optional[int] = 20,
         min_new_tokens: Optional[int] = 0,
         truncate_input_tokens: Optional[int] = 0,
         decoding_method: Optional[str] = "GREEDY",
         top_k: Optional[int] = 0,
         top_p: Optional[float] = 0.0,
         typical_p: Optional[float] = 0.0,
         temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
         repetition_penalty: Optional[float] = 0.0,
         max_time: Optional[float] = None,
-        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
         stop_sequences: Optional[str] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Run output stream inferencing against the model running in TGIS

         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: str
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum numbers of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum numbers of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to provided number of tokens. This can be
-                use to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, thus throw error.
-            decoding_method: str
-                Parameters for conditionally penalizing / boosting
-                candidate tokens during decoding.
-                Options: "GREEDY" (default), "SAMPLING"
-            temperature: float
-                The value used to modulate the next token probabilities.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            top_k: int
-                The number of highest probability vocabulary tokens to keep for
-                top-k-filtering. Only applicable when decoding_method is SAMPLING.
-                Default: 0 - means disabled
-            top_p: float
-                If set to float < 1, only the smallest set of most probable tokens
-                with probabilities that add up to top_p or higher are kept for
-                generation.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            typical_p: float
-                Local typicality measures how similar the conditional probability of
-                predicting a target token next is to the expected conditional
-                probability of predicting a random token next, given the partial text
-                already generated. If set to float < 1, the smallest set of the most
-                locally typical tokens with probabilities that add up to typical_p
-                or higher are kept for generation.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            repetition_penalty: float
-                The more a token is used within generation the more it is penalized
-                to not be picked in successive generation passes.
-                Default: 0.0 - means no penalty - equivalent to 1.0
-            stop_sequences: List(str)
-                Sequences to be considered for stopping generation.
+            {}
         Returns:
             Iterable[GeneratedTextStreamResult]
-        """
+        """.format(
+            GENERATE_FUNCTION_ARGS
+        )
+
         error.value_check(
             "",
             self.enable_backend,
             "Backend must be configured and loaded with this module \
                before executing `run_stream_out` call.",
         )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.stream_generate(
             text=verbalized_text,
             preserve_input_text=preserve_input_text,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
             temperature=temperature,
+            seed=seed,
             repetition_penalty=repetition_penalty,
             max_time=max_time,
             exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )
diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py
index 2feb2778..a5998015 100644
--- a/caikit_nlp/modules/text_generation/text_generation_tgis.py
+++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -30,18 +30,26 @@
 import alog

 # Local
+from ...data_model import ExponentialDecayLengthPenalty
 from ...resources.pretrained_model import (
     HFAutoCausalLM,
     HFAutoSeq2SeqLM,
     PretrainedModelBase,
 )
-from ...toolkit.tgis_utils import VALID_DECODING_METHODS, TGISGenerationClient
+from ...toolkit.tgis_utils import (
+    GENERATE_FUNCTION_ARGS,
+    VALID_DECODING_METHODS,
+    TGISGenerationClient,
+)
 from .text_generation_local import TextGeneration

 log = alog.use_channel("TXT_GEN")
 error = error_handler.get(log)

 # pylint: disable=too-many-instance-attributes
+# pylint: disable=duplicate-code
+
 @module(backend_type=TGISBackend.backend_type, base_module=TextGeneration)
 class TextGenerationTGIS(ModuleBase):
     """Module to provide text generation capabilities"""
@@ -205,64 +213,25 @@ def run(
         top_p: Optional[float] = 0.0,
         typical_p: Optional[float] = 0.0,
         temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
         repetition_penalty: Optional[float] = 0.0,
         max_time: Optional[float] = None,
-        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
         stop_sequences: Optional[str] = None,
     ) -> GeneratedTextResult:
         """Run inference against the model running in TGIS.

         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: bool
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum numbers of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum numbers of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to provided number of tokens. This can be
-                use to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, thus throw error.
-            decoding_method: str
-                Parameters for conditionally penalizing / boosting
-                candidate tokens during decoding.
-                Options: "GREEDY" (default), "SAMPLING"
-            temperature: float
-                The value used to modulate the next token probabilities.
-                Default: 0.0 - means disabled - equivalent to 1.0
-            top_k: int
-                The number of highest probability vocabulary tokens to keep for
-                top-k-filtering. Only applicable when decoding_method is SAMPLING.
+        Default: 1.0 - means disabled - equivalent to 1.0
+    seed: int
+        Random seed to control sampling. Only applicable when decoding_method
+        is SAMPLING. Default: None
+    repetition_penalty: float
+        The more a token is used within generation the more it is penalized
+        to not be picked in successive generation passes.
+        Default: 0.0 - means no penalty - equivalent to 1.0
+    max_time: float
+        Amount of time in seconds that the query should take maximum.
+        NOTE: this does not include network overhead.
+        Range: 0-120.0
+    exponential_decay_length_penalty: Tuple(int, float)
+        This Tuple adds an exponentially increasing length penalty, after
+        a certain amount of tokens have been generated. The tuple shall
+        consist of: (start_index, decay_factor) where start_index
+        indicates where penalty starts and decay_factor represents the factor
+        of exponential decay
+    stop_sequences: List[str]:
+        List of strings to be used as stopping criteria
+"""

 def get_params(
     preserve_input_text,
+    eos_token,
     max_new_tokens,
     min_new_tokens,
     truncate_input_tokens,
@@ -45,6 +113,7 @@ def get_params(
     top_p,
     typical_p,
     temperature,
+    seed,
     repetition_penalty,
     max_time,
     exponential_decay_length_penalty,
     stop_sequences,
 ):
     """Get generation parameters

     Args:
         preserve_input_text: str
             Whether or not the source string should be contained in the generated output,
             e.g., as a prefix.
         eos_token: str
             A special token representing the end of a sentence.
-        max_new_tokens: int
-            The maximum numbers of tokens to generate.
-        min_new_tokens: int
-            The minimum numbers of tokens to generate.
-        truncate_input_tokens: int
-            Truncate inputs to provided number of tokens.
-    """
+        {}
+    """.format(
+        GENERATE_FUNCTION_ARGS
+    )

     if decoding_method == "GREEDY":
         decoding = generation_pb2.DecodingMethod.GREEDY
     elif decoding_method == "SAMPLING":
         decoding = generation_pb2.DecodingMethod.SAMPLE

-    # decoding = generation_pb2.DecodingMethod.__getattr__(decoding_method)
-
     sampling_parameters = generation_pb2.SamplingParameters(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         typical_p=typical_p,
-        # seed=seed
+        seed=seed,
     )

     res_options = generation_pb2.ResponseOptions(
@@ -89,15 +153,32 @@ def get_params(
         token_ranks=True,
     )
     stopping = generation_pb2.StoppingCriteria(
-        stop_sequences=stop_sequences,
+        stop_sequences=stop_sequences or [eos_token] if eos_token else None,
         max_new_tokens=max_new_tokens,
         min_new_tokens=min_new_tokens,
         time_limit_millis=max_time,
     )

+    start_index = None
+    decay_factor = None
+
+    if exponential_decay_length_penalty:
+        if isinstance(exponential_decay_length_penalty, tuple):
+            start_index = exponential_decay_length_penalty[0]
+            decay_factor = exponential_decay_length_penalty[1]
+        elif isinstance(
+            exponential_decay_length_penalty, ExponentialDecayLengthPenalty
+        ):
+            start_index = exponential_decay_length_penalty.start_index
+            decay_factor = exponential_decay_length_penalty.decay_factor
+
+    length_penalty = generation_pb2.DecodingParameters.LengthPenalty(
+        start_index=start_index, decay_factor=decay_factor
+    )
+
     decoding_parameters = generation_pb2.DecodingParameters(
         repetition_penalty=repetition_penalty,
-        length_penalty=exponential_decay_length_penalty,
+        length_penalty=length_penalty,
     )

     params = generation_pb2.Parameters(
@@ -216,6 +297,7 @@ def unary_generate(
         top_p,
         typical_p,
         temperature,
+        seed,
         repetition_penalty,
         max_time,
         exponential_decay_length_penalty,
         stop_sequences,
     ) -> GeneratedTextResult:
         """Generate unary output from model in TGIS

         Args:
             text: str
                 Source string to be encoded for generation.
             preserve_input_text: bool
                 Whether or not the source string should be contained in the generated output,
                 e.g., as a prefix.
-            max_new_tokens: int
-                The maximum numbers of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum numbers of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to provided number of tokens. This can be
-                use to avoid failing due to input being longer than
-                configured limits.
-                0 - means don't truncate, thus throw error.
+            {}
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
-        """
+        """.format(
+            GENERATE_FUNCTION_ARGS
+        )
+
         # In case internal client is not configured - generation
         # cannot be done (individual modules may already check
         # for this)

         params = get_params(
             preserve_input_text=preserve_input_text,
+            eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
             temperature=temperature,
+            seed=seed,
             repetition_penalty=repetition_penalty,
             max_time=max_time,
             exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )
@@ -312,6 +386,7 @@ def stream_generate(
         top_p,
         typical_p,
         temperature,
+        seed,
         repetition_penalty,
         max_time,
         exponential_decay_length_penalty,
         stop_sequences,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Generate stream output from model in TGIS

         Args:
             text: str
                 Source string to be encoded for generation.
             preserve_input_text: bool
                 Whether or not the source string should be contained in the generated output,
                 e.g., as a prefix.
-            max_new_tokens: int
-                Maximum tokens for the model to generate
-            min_new_tokens: int
-                Minimum tokens for the model to generate
-            truncate_input_tokens: int
-                Truncate inputs to provided number of tokens. This can be
-                use to avoid failing due to input being longer than
-                configured limits.
-                0 - means don't truncate, thus throw error.
-
+            {}
         Returns:
             Iterable[GeneratedTextStreamResult]
-        """
+        """.format(
+            GENERATE_FUNCTION_ARGS
+        )
+
         # In case internal client is not configured - generation
         # cannot be done (individual modules may already check
         # for this)

         params = get_params(
             preserve_input_text=preserve_input_text,
+            eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
             decoding_method=decoding_method,
             top_k=top_k,
             top_p=top_p,
             typical_p=typical_p,
             temperature=temperature,
+            seed=seed,
             repetition_penalty=repetition_penalty,
             max_time=max_time,
             exponential_decay_length_penalty=exponential_decay_length_penalty,
             stop_sequences=stop_sequences,
         )
diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py
index 16686f02..65dd34e3 100644
--- a/tests/modules/text_generation/test_text_generation_tgis.py
+++ b/tests/modules/text_generation/test_text_generation_tgis.py
@@ -15,7 +15,7 @@
 import caikit

 # Local
-from caikit_nlp.data_model.generation import GenerationTrainRecord
+from caikit_nlp.data_model import ExponentialDecayLengthPenalty, GenerationTrainRecord
 from caikit_nlp.modules.text_generation import TextGeneration, TextGenerationTGIS
 from caikit_nlp.resources.pretrained_model.hf_auto_seq2seq_lm import HFAutoSeq2SeqLM
 from tests.fixtures import (
@@ -182,7 +182,7 @@ def test_bootstrap_and_run_causallm_with_optional_params():
         temperature=0.75,
         repetition_penalty=0.3,
         max_time=1000,
-        exponential_decay_length_penalty=(),
+        exponential_decay_length_penalty=(1, 0.95),
         stop_sequences=["This is a test"],
     )
     StubTGISClient.validate_unary_generate_response(result)
@@ -206,7 +206,9 @@ def test_bootstrap_and_run_stream_out_with_optional_dependencies():
         temperature=0.75,
         repetition_penalty=0.3,
         max_time=1000,
-        exponential_decay_length_penalty=(),
+        exponential_decay_length_penalty=ExponentialDecayLengthPenalty(
+            start_index=2, decay_factor=0.95
+        ),
         stop_sequences=["This is a test"],
     )
     StubTGISClient.validate_stream_generate_response(stream_result)

From 84a15a0f82846f1102636490e3d1af61d4fc1c63 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Tue, 29 Aug 2023 22:16:51 -0700
Subject: [PATCH 06/12] =?UTF-8?q?=E2=9C=A8=20Support=20input=5Ftoken=5Fcou?=
 =?UTF-8?q?nt=20in=20text=20gen=20output?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 caikit_nlp/toolkit/tgis_utils.py | 2 ++
 tests/fixtures/__init__.py       | 4 ++++
 2 files changed, 6 insertions(+)

diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py
index d4c05441..7e6dea78 100644
--- a/caikit_nlp/toolkit/tgis_utils.py
+++ b/caikit_nlp/toolkit/tgis_utils.py
@@ -298,6 +298,7 @@ def unary_generate(
             generated_tokens=response.generated_token_count,
             finish_reason=response.stop_reason,
             producer_id=self.producer_id,
+            input_token_count=response.input_token_count,
         )

     def stream_generate(
@@ -385,6 +386,7 @@ def stream_generate(
                 finish_reason=stream_part.stop_reason,
                 generated_tokens=stream_part.generated_token_count,
                 seed=stream_part.seed,
+                input_token_count=stream_part.input_token_count,
             )
             token_list = []
             for token in stream_part.tokens:
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 99d5febe..871ea76d 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -186,6 +186,7 @@ def unary_generate(request):
         fake_result.stop_reason = 5
         fake_result.generated_token_count = 1
         fake_result.text = "moose"
+        fake_result.input_token_count = 1
         fake_response.responses = [fake_result]
         return fake_response

@@ -195,6 +196,7 @@ def stream_generate(request):
         fake_stream.stop_reason = 5
         fake_stream.generated_token_count = 1
         fake_stream.seed = 10
+        fake_stream.input_token_count = 1
         token = mock.Mock()
         token.text = "moose"
         token.logprob = 0.2
@@ -209,6 +211,7 @@ def validate_unary_generate_response(result):
         assert result.generated_text == "moose"
         assert result.generated_tokens == 1
         assert result.finish_reason == 5
+        assert result.input_token_count == 1

     @staticmethod
     def validate_stream_generate_response(stream_result):
@@ -223,6 +226,7 @@ def validate_stream_generate_response(stream_result):
         assert first_result.details.finish_reason == 5
         assert first_result.details.generated_tokens == 1
         assert first_result.details.seed == 10
+        assert first_result.details.input_token_count == 1

 class StubTGISBackend(TGISBackend):

From f1e3f518532822caa9189bd2fa27cd231d07c6e9 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Wed, 30 Aug 2023 10:29:50 -0700
Subject: [PATCH 07/12] =?UTF-8?q?=F0=9F=A6=BA=20Validate=20inference=20par?=
 =?UTF-8?q?ams?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
---
 .../text_generation/peft_tgis_remote.py      |  18 +--
 .../text_generation/text_generation_tgis.py  |  19 +--
 caikit_nlp/toolkit/tgis_utils.py             | 126 +++++++++++++++++-
 .../test_text_generation_tgis.py             |  25 +++-
 4 files changed, 144 insertions(+), 44 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index 88710991..a5c04bf3 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -32,11 +32,7 @@

 # Local
 from ...data_model import ExponentialDecayLengthPenalty
-from ...toolkit.tgis_utils import (
-    GENERATE_FUNCTION_ARGS,
-    VALID_DECODING_METHODS,
-    TGISGenerationClient,
-)
+from ...toolkit.tgis_utils import GENERATE_FUNCTION_ARGS, TGISGenerationClient
 from ...toolkit.verbalizer_utils import render_verbalizer
 from . import PeftPromptTuning

@@ -202,12 +198,6 @@ def run(
             self.enable_backend,
             "Backend must be configured and loaded with this module before executing `run` call.",
         )
-        error.value_check(
-            "",
-            decoding_method in VALID_DECODING_METHODS,
-            f"Decoding method [{decoding_method}] not in valid decoding methods: "
-            f"[{VALID_DECODING_METHODS}]",
-        )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.unary_generate(
             text=verbalized_text,
@@ -264,12 +254,6 @@ def run_stream_out(
             "Backend must be configured and loaded with this module \
                before executing `run_stream_out` call.",
         )
-        error.value_check(
-            "",
-            decoding_method in VALID_DECODING_METHODS,
-            f"Decoding method [{decoding_method}] not in valid decoding methods: "
-            f"[{VALID_DECODING_METHODS}]",
-        )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.stream_generate(
             text=verbalized_text,
diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py
index a5998015..c3e3156f 100644
--- a/caikit_nlp/modules/text_generation/text_generation_tgis.py
+++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -36,11 +36,7 @@
     HFAutoSeq2SeqLM,
     PretrainedModelBase,
 )
-from ...toolkit.tgis_utils import (
-    GENERATE_FUNCTION_ARGS,
-    VALID_DECODING_METHODS,
-    TGISGenerationClient,
-)
+from ...toolkit.tgis_utils import GENERATE_FUNCTION_ARGS, TGISGenerationClient
 from .text_generation_local import TextGeneration

 log = alog.use_channel("TXT_GEN")
@@ -232,13 +228,6 @@ def run(
             GENERATE_FUNCTION_ARGS
         )

-        error.value_check(
-            "",
-            decoding_method in VALID_DECODING_METHODS,
-            f"Decoding method [{decoding_method}] not in valid decoding methods: "
-            f"[{VALID_DECODING_METHODS}]",
-        )
-
         if self._model_loaded:
             return self.tgis_generation_client.unary_generate(
                 text=text,
@@ -289,12 +278,6 @@ def run_stream_out(
             GENERATE_FUNCTION_ARGS
         )

-        error.value_check(
-            "",
-            decoding_method in VALID_DECODING_METHODS,
-            f"Decoding method [{decoding_method}] not in valid decoding methods: "
-            f"[{VALID_DECODING_METHODS}]",
-        )
         if self._model_loaded:
             return self.tgis_generation_client.stream_generate(
                 text=text,
diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py
index 7e6dea78..ec572f32 100644
--- a/caikit_nlp/toolkit/tgis_utils.py
+++ b/caikit_nlp/toolkit/tgis_utils.py
@@ -33,8 +33,6 @@
 log = alog.use_channel("TGIS_UTILS")
 error = error_handler.get(log)

-VALID_DECODING_METHODS = ["GREEDY", "SAMPLING"]
-
 # pylint: disable=duplicate-code

 GENERATE_FUNCTION_ARGS = """
@@ -100,6 +98,87 @@

+def validate_inf_params(
+    text,
+    preserve_input_text,
+    eos_token,
+    max_new_tokens,
+    min_new_tokens,
+    truncate_input_tokens,
+    decoding_method,
+    top_k,
+    top_p,
+    typical_p,
+    temperature,
+    seed,
+    repetition_penalty,
+    max_time,
+    exponential_decay_length_penalty,
+    stop_sequences,
+):
+    """Validate inference parameters
+
+
Args: + eos_token: str + A special token representing the end of a sentence. + {} + """.format( + GENERATE_FUNCTION_ARGS + ) + error.type_check("", str, text=text) + error.type_check("", bool, preserve_input_text=preserve_input_text) + error.type_check("", str, eos_token=eos_token) + error.type_check( + "", int, allow_none=True, max_new_tokens=max_new_tokens + ) + error.type_check( + "", int, allow_none=True, min_new_tokens=min_new_tokens + ) + + error.value_check( + "", + max_new_tokens >= min_new_tokens, + f"Maximum new tokens [{max_new_tokens}] has to be greater than minimum new tokens \ + [{min_new_tokens}]", + ) + + error.type_check( + "", + int, + allow_none=True, + truncate_input_tokens=truncate_input_tokens, + ) + + valid_decoding_methods = ["GREEDY", "SAMPLING"] + + error.value_check( + "", + decoding_method in valid_decoding_methods, + f"Decoding method [{decoding_method}] not in valid decoding methods: " + f"[{valid_decoding_methods}]", + ) + error.type_check("", int, allow_none=True, top_k=top_k) + error.type_check("", float, allow_none=True, top_p=top_p) + error.type_check("", float, allow_none=True, typical_p=typical_p) + error.type_check("", float, allow_none=True, temperature=temperature) + error.type_check("", int, allow_none=True, seed=seed) + error.type_check( + "", float, allow_none=True, repetition_penalty=repetition_penalty + ) + error.type_check("", float, allow_none=True, max_time=max_time) + error.type_check( + "", + ExponentialDecayLengthPenalty, + tuple, + allow_none=True, + exponential_decay_length_penalty=exponential_decay_length_penalty, + ) + + error.type_check_all( + "", str, allow_none=True, stop_sequences=stop_sequences + ) + + def get_params( preserve_input_text, eos_token, @@ -122,9 +201,6 @@ def get_params( """Get generation parameters Args: - preserve_input_text: str - Whether or not the source string should be contained in the generated output, - e.g., as a prefix. eos_token: str A special token representing the end of a sentence. 
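# A stand-alone sketch of the bounds check above (hypothetical helper, shown
# only to make the failure mode concrete). The guard is
# `max_new_tokens >= min_new_tokens`, so equality passes; a message reading
# "greater than or equal to" would match the condition exactly:
def check_token_bounds(max_new_tokens: int, min_new_tokens: int) -> None:
    if max_new_tokens < min_new_tokens:
        raise ValueError(
            f"Maximum new tokens [{max_new_tokens}] has to be greater than "
            f"or equal to minimum new tokens [{min_new_tokens}]"
        )

check_token_bounds(20, 0)  # passes: the run() defaults
# check_token_bounds(20, 50) raises ValueError, mirroring
# test_invalid_optional_params later in this patch.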
     {}
@@ -156,7 +232,7 @@
         stop_sequences=stop_sequences or [eos_token] if eos_token else None,
         max_new_tokens=max_new_tokens,
         min_new_tokens=min_new_tokens,
-        time_limit_millis=max_time,
+        time_limit_millis=int(max_time * 1000) if max_time else None,
     )
 
     start_index = None
@@ -247,6 +323,25 @@ def unary_generate(
             "Backend must be configured and loaded for generate",
         )
 
+        validate_inf_params(
+            text=text,
+            preserve_input_text=preserve_input_text,
+            eos_token=self.eos_token,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            temperature=temperature,
+            seed=seed,
+            repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
+            stop_sequences=stop_sequences,
+        )
+
         log.debug("Building protobuf request to send to TGIS")
 
         params = get_params(
@@ -344,6 +439,25 @@ def stream_generate(
             )
 
         log.debug("Building protobuf request to send to TGIS")
+        validate_inf_params(
+            text=text,
+            preserve_input_text=preserve_input_text,
+            eos_token=self.eos_token,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            temperature=temperature,
+            seed=seed,
+            repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
+            stop_sequences=stop_sequences,
+        )
+
         params = get_params(
             preserve_input_text=preserve_input_text,
             eos_token=self.eos_token,
diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py
index 65dd34e3..8909ddbb 100644
--- a/tests/modules/text_generation/test_text_generation_tgis.py
+++ b/tests/modules/text_generation/test_text_generation_tgis.py
@@ -181,8 +181,8 @@ def test_bootstrap_and_run_causallm_with_optional_params():
         typical_p=0.5,
         temperature=0.75,
         repetition_penalty=0.3,
-        max_time=1000,
-        exponential_decay_length_penalty=(1, 0.95),
+        max_time=10.5,
+        exponential_decay_length_penalty=(2, 0.95),
         stop_sequences=["This is a test"],
     )
     StubTGISClient.validate_unary_generate_response(result)
@@ -205,10 +205,29 @@ def test_bootstrap_and_run_stream_out_with_optional_dependencies():
         typical_p=0.5,
         temperature=0.75,
         repetition_penalty=0.3,
-        max_time=1000,
+        max_time=10.5,
         exponential_decay_length_penalty=ExponentialDecayLengthPenalty(
             start_index=2, decay_factor=0.95
         ),
         stop_sequences=["This is a test"],
     )
     StubTGISClient.validate_stream_generate_response(stream_result)
+
+
+def test_invalid_optional_params():
+    """Check that an error is thrown when invalid inference params are used to run causallm models"""
+
+    model = TextGenerationTGIS.bootstrap(
+        CAUSAL_LM_MODEL, load_backend=StubTGISBackend()
+    )
+
+    with pytest.raises(ValueError):
+        _ = model.run(
+            SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=20, min_new_tokens=50
+        )
+
+    with pytest.raises(TypeError):
+        _ = model.run(SAMPLE_TEXT, preserve_input_text=True, top_k=0.5)
+
+    with pytest.raises(TypeError):
+        _ = model.run(SAMPLE_TEXT, exponential_decay_length_penalty=[2, 0.95])

From d421b3c32923196e6ae9f6184103f095e6fc3981 Mon Sep 17 00:00:00 2001
From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Date: Wed, 30 Aug 2023 13:46:07 -0700
Subject: [PATCH 08/12]
=?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20defaults=20?= =?UTF-8?q?and=20minor=20refactor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- .../text_generation/peft_tgis_remote.py | 12 +++---- .../text_generation/text_generation_tgis.py | 12 +++---- caikit_nlp/toolkit/tgis_utils.py | 34 ++++++++----------- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index a5c04bf3..65528c4d 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -169,11 +169,11 @@ def run( truncate_input_tokens: Optional[int] = 0, decoding_method: Optional[str] = "GREEDY", top_k: Optional[int] = 0, - top_p: Optional[float] = 0.0, - typical_p: Optional[float] = 0.0, + top_p: Optional[float] = 1.0, + typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, seed: Optional[int] = None, - repetition_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] @@ -227,11 +227,11 @@ def run_stream_out( truncate_input_tokens: Optional[int] = 0, decoding_method: Optional[str] = "GREEDY", top_k: Optional[int] = 0, - top_p: Optional[float] = 0.0, - typical_p: Optional[float] = 0.0, + top_p: Optional[float] = 1.0, + typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, seed: Optional[int] = None, - repetition_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index c3e3156f..08bbfefc 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -206,11 +206,11 @@ def run( truncate_input_tokens: Optional[int] = 0, decoding_method: Optional[str] = "GREEDY", top_k: Optional[int] = 0, - top_p: Optional[float] = 0.0, - typical_p: Optional[float] = 0.0, + top_p: Optional[float] = 1.0, + typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, seed: Optional[int] = None, - repetition_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] @@ -257,11 +257,11 @@ def run_stream_out( truncate_input_tokens: Optional[int] = 0, decoding_method: Optional[str] = "GREEDY", top_k: Optional[int] = 0, - top_p: Optional[float] = 0.0, - typical_p: Optional[float] = 0.0, + top_p: Optional[float] = 1.0, + typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, seed: Optional[int] = None, - repetition_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ Union[Tuple[int, float], ExponentialDecayLengthPenalty] diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py index ec572f32..c90ce86b 100644 --- a/caikit_nlp/toolkit/tgis_utils.py +++ b/caikit_nlp/toolkit/tgis_utils.py @@ -64,7 +64,7 @@ 
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Only applicable when decoding_method is SAMPLING. - Default: 0.0 - means disabled - equivalent to 1.0 + Default: 1.0 - means disabled - 0.0 equivalent to 1.0 typical_p: float Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional @@ -73,7 +73,7 @@ locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. Only applicable when decoding_method is SAMPLING. - Default: 0.0 - means disabled - equivalent to 1.0 + Default: 1.0 - means disabled - 0.0 equivalent to 1.0 temperature: float The value used to modulate the next token probabilities. Only applicable when decoding_method is SAMPLING. @@ -84,7 +84,7 @@ repetition_penalty: float The more a token is used within generation the more it is penalized to not be picked in successive generation passes. - Default: 0.0 - means no penalty - equivalent to 1.0 + Default: 1.0 - means no penalty - 0.0 equivalent to 1.0 max_time: float Amount of time in seconds that the query should take maximum. NOTE: this does not include network overhead. @@ -235,26 +235,22 @@ def get_params( time_limit_millis=int(max_time * 1000) if max_time else None, ) - start_index = None - decay_factor = None - if exponential_decay_length_penalty: - if isinstance(exponential_decay_length_penalty, tuple): - start_index = exponential_decay_length_penalty[0] - decay_factor = exponential_decay_length_penalty[1] - elif isinstance( - exponential_decay_length_penalty, ExponentialDecayLengthPenalty - ): - start_index = exponential_decay_length_penalty.start_index - decay_factor = exponential_decay_length_penalty.decay_factor - - length_penalty = generation_pb2.DecodingParameters.LengthPenalty( - start_index=start_index, decay_factor=decay_factor - ) + if isinstance(exponential_decay_length_penalty, ExponentialDecayLengthPenalty): + exponential_decay_length_penalty = ( + exponential_decay_length_penalty.start_index, + exponential_decay_length_penalty.decay_factor, + ) + exponential_decay_length_penalty = ( + generation_pb2.DecodingParameters.LengthPenalty( + start_index=exponential_decay_length_penalty[0], + decay_factor=exponential_decay_length_penalty[1], + ) + ) decoding_parameters = generation_pb2.DecodingParameters( repetition_penalty=repetition_penalty, - length_penalty=length_penalty, + length_penalty=exponential_decay_length_penalty, ) params = generation_pb2.Parameters( From b2c1ed2f8bdb9c14c8d715495926b1e99f55e7c3 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Thu, 31 Aug 2023 10:50:03 -0700 Subject: [PATCH 09/12] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Upgrade=20caikit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e3249f0f..d101c519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.18.0,<0.20.0", + "caikit[runtime-grpc,runtime-http]>=0.18.1,<0.20.0", "caikit-tgis-backend>=0.1.16,<0.2.0", # TODO: loosen dependencies "accelerate>=0.21.0", From c0bb81568957dbba71c460c4a5312ad5c3df26d5 Mon Sep 
17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Thu, 31 Aug 2023 11:10:40 -0700 Subject: [PATCH 10/12] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20tgis=20ut?= =?UTF-8?q?ils?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- .../text_generation/peft_tgis_remote.py | 9 +- .../text_generation/text_generation_tgis.py | 10 ++- .../{ => text_generation}/tgis_utils.py | 86 ++----------------- 3 files changed, 21 insertions(+), 84 deletions(-) rename caikit_nlp/toolkit/{ => text_generation}/tgis_utils.py (78%) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 65528c4d..4203af37 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -32,7 +32,10 @@ # Local from ...data_model import ExponentialDecayLengthPenalty -from ...toolkit.tgis_utils import GENERATE_FUNCTION_ARGS, TGISGenerationClient +from ...toolkit.text_generation.tgis_utils import ( + GENERATE_FUNCTION_TGIS_ARGS, + TGISGenerationClient, +) from ...toolkit.verbalizer_utils import render_verbalizer from . import PeftPromptTuning @@ -190,7 +193,7 @@ def run( GeneratedTextResult Generated text result produced by TGIS. """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) error.value_check( @@ -245,7 +248,7 @@ def run_stream_out( Returns: Iterable[GeneratedTextStreamResult] """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) error.value_check( diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 08bbfefc..1cb54a76 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -36,14 +36,16 @@ HFAutoSeq2SeqLM, PretrainedModelBase, ) -from ...toolkit.tgis_utils import GENERATE_FUNCTION_ARGS, TGISGenerationClient +from ...toolkit.text_generation.tgis_utils import ( + GENERATE_FUNCTION_TGIS_ARGS, + TGISGenerationClient, +) from .text_generation_local import TextGeneration log = alog.use_channel("TXT_GEN") error = error_handler.get(log) # pylint: disable=too-many-instance-attributes -# pylint: disable=duplicate-code @module(backend_type=TGISBackend.backend_type, base_module=TextGeneration) @@ -225,7 +227,7 @@ def run( GeneratedTextResult Generated text result produced by TGIS. 
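# One caveat with the `""".format(...)` docstring pattern used in these
# modules: because the string literal is immediately consumed by a
# `.format(...)` call, the function's first statement is a call expression
# rather than a bare string, so CPython never assigns it to `__doc__` and
# help() shows nothing at runtime. The interpolated text still reads
# correctly in source. A small decorator (hypothetical, not part of this
# series) would keep the shared argument docs and a live docstring:
SHARED_ARGS_DOC = "text: str\n        Input string for the generation model."

def format_docstring(**kwargs):
    def wrap(fn):
        if fn.__doc__:  # only format when a real docstring is attached
            fn.__doc__ = fn.__doc__.format(**kwargs)
        return fn
    return wrap

@format_docstring(args=SHARED_ARGS_DOC)
def run(text: str) -> None:
    """Run inference against the model running in TGIS.

    Args:
        {args}
    """

assert "Input string" in run.__doc__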
""".format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) if self._model_loaded: @@ -275,7 +277,7 @@ def run_stream_out( Returns: Iterable[GeneratedTextStreamResult] """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) if self._model_loaded: diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py similarity index 78% rename from caikit_nlp/toolkit/tgis_utils.py rename to caikit_nlp/toolkit/text_generation/tgis_utils.py index c90ce86b..5943e78b 100644 --- a/caikit_nlp/toolkit/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -28,76 +28,18 @@ import alog # Local -from ..data_model import ExponentialDecayLengthPenalty +from ...data_model import ExponentialDecayLengthPenalty +from .model_run_utils import GENERATE_FUNCTION_ARGS log = alog.use_channel("TGIS_UTILS") error = error_handler.get(log) -# pylint: disable=duplicate-code - -GENERATE_FUNCTION_ARGS = """ - text: str - Input string to be used to the generation model. +GENERATE_FUNCTION_TGIS_ARGS = """ + {} preserve_input_text: str Whether or not the source string should be contained in the generated output, e.g., as a prefix. - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 20 - min_new_tokens: int - The minimum numbers of tokens to generate. - Default: 0 - means no minimum - truncate_input_tokens: int - Truncate inputs to provided number of tokens. This can be - use to avoid failing due to input being longer than - configured limits. - Default: 0 - means don't truncate, thus throw error. - decoding_method: str - Parameters for conditionally penalizing / boosting - candidate tokens during decoding. - Options: "GREEDY" (default), "SAMPLING" - top_k: int - The number of highest probability vocabulary tokens to keep for - top-k-filtering. Only applicable when decoding_method is SAMPLING. - Default: 0 - means disabled - top_p: float - If set to float < 1, only the smallest set of most probable tokens - with probabilities that add up to top_p or higher are kept for - generation. Only applicable when decoding_method is SAMPLING. - Default: 1.0 - means disabled - 0.0 equivalent to 1.0 - typical_p: float - Local typicality measures how similar the conditional probability of - predicting a target token next is to the expected conditional - probability of predicting a random token next, given the partial text - already generated. If set to float < 1, the smallest set of the most - locally typical tokens with probabilities that add up to typical_p - or higher are kept for generation. Only applicable when decoding_method - is SAMPLING. - Default: 1.0 - means disabled - 0.0 equivalent to 1.0 - temperature: float - The value used to modulate the next token probabilities. - Only applicable when decoding_method is SAMPLING. - Default: 1.0 - means disabled - equivalent to 1.0 - seed: int - Random seed to control sampling. Only applicable when decoding_method - is SAMPLING. Default: None - repetition_penalty: float - The more a token is used within generation the more it is penalized - to not be picked in successive generation passes. - Default: 1.0 - means no penalty - 0.0 equivalent to 1.0 - max_time: float - Amount of time in seconds that the query should take maximum. - NOTE: this does not include network overhead. - Range: 0-120.0 - exponential_decay_length_penalty: Tuple(int, float) - This Tuple adds an exponentially increasing length penalty, after - a certain amount of tokens have been generated. 
The tuple shall - consist of: (start_index, decay_factor) where start_index - indicates where penalty starts and decay_factor represents the factor - of exponential decay - stop_sequences: List[str]: - List of strings to be used as stopping criteria -""" +""".format(GENERATE_FUNCTION_ARGS) def validate_inf_params( @@ -125,7 +67,7 @@ def validate_inf_params( A special token representing the end of a sentence. {} """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) error.type_check("", str, text=text) error.type_check("", bool, preserve_input_text=preserve_input_text) @@ -205,7 +147,7 @@ def get_params( A special token representing the end of a sentence. {} """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) if decoding_method == "GREEDY": @@ -297,17 +239,12 @@ def unary_generate( """Generate unary output from model in TGIS Args: - text: str - Source string to be encoded for generation. - preserve_input_text: bool - Whether or not the source string should be contained in the generated output, - e.g., as a prefix. {} Returns: GeneratedTextResult Generated text result produced by TGIS. """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) # In case internal client is not configured - generation @@ -413,16 +350,11 @@ def stream_generate( """Generate stream output from model in TGIS Args: - text: str - Source string to be encoded for generation. - preserve_input_text: bool - Whether or not the source string should be contained in the generated output, - e.g., as a prefix. {} Returns: Iterable[GeneratedTextStreamResult] """.format( - GENERATE_FUNCTION_ARGS + GENERATE_FUNCTION_TGIS_ARGS ) # In case internal client is not configured - generation From bb192cf40b4eb867d73a7d9e9b84b0117e69a509 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Thu, 31 Aug 2023 11:19:19 -0700 Subject: [PATCH 11/12] =?UTF-8?q?=F0=9F=8E=A8=20formatting?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- .../toolkit/text_generation/tgis_utils.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 5943e78b..16bf99b9 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -29,7 +29,7 @@ # Local from ...data_model import ExponentialDecayLengthPenalty -from .model_run_utils import GENERATE_FUNCTION_ARGS +from .model_run_utils import GENERATE_FUNCTION_ARGS, VALID_DECODING_METHODS log = alog.use_channel("TGIS_UTILS") error = error_handler.get(log) @@ -39,7 +39,9 @@ preserve_input_text: str Whether or not the source string should be contained in the generated output, e.g., as a prefix. 
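# A worked sketch of the (start_index, decay_factor) tuple documented above.
# The arithmetic below is an assumption mirroring the usual
# exponential_decay_length_penalty behavior; the TGIS server's exact scaling
# may differ. No penalty applies until start_index tokens have been
# generated; from then on the end-of-sequence score is scaled by
# decay_factor ** (n - start_index) at step n, compounding per token:
def length_penalty_factor(n: int, start_index: int, decay_factor: float) -> float:
    if n <= start_index:
        return 1.0  # penalty has not kicked in yet
    return decay_factor ** (n - start_index)

# With start_index=2, decay_factor=0.95 (the values used in the tests), the
# factor stays 1.0 through step 2, then ~0.95, ~0.90, ~0.86, ...
factors = [length_penalty_factor(n, 2, 0.95) for n in range(1, 6)]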
-""".format(GENERATE_FUNCTION_ARGS) +""".format( + GENERATE_FUNCTION_ARGS +) def validate_inf_params( @@ -69,14 +71,14 @@ def validate_inf_params( """.format( GENERATE_FUNCTION_TGIS_ARGS ) - error.type_check("", str, text=text) + error.type_check("", str, text=text) error.type_check("", bool, preserve_input_text=preserve_input_text) - error.type_check("", str, eos_token=eos_token) + error.type_check("", str, eos_token=eos_token) error.type_check( - "", int, allow_none=True, max_new_tokens=max_new_tokens + "", int, allow_none=True, max_new_tokens=max_new_tokens ) error.type_check( - "", int, allow_none=True, min_new_tokens=min_new_tokens + "", int, allow_none=True, min_new_tokens=min_new_tokens ) error.value_check( @@ -87,31 +89,29 @@ def validate_inf_params( ) error.type_check( - "", + "", int, allow_none=True, truncate_input_tokens=truncate_input_tokens, ) - valid_decoding_methods = ["GREEDY", "SAMPLING"] - error.value_check( "", - decoding_method in valid_decoding_methods, + decoding_method in VALID_DECODING_METHODS, f"Decoding method [{decoding_method}] not in valid decoding methods: " - f"[{valid_decoding_methods}]", + f"[{VALID_DECODING_METHODS}]", ) - error.type_check("", int, allow_none=True, top_k=top_k) - error.type_check("", float, allow_none=True, top_p=top_p) - error.type_check("", float, allow_none=True, typical_p=typical_p) - error.type_check("", float, allow_none=True, temperature=temperature) - error.type_check("", int, allow_none=True, seed=seed) + error.type_check("", int, allow_none=True, top_k=top_k) + error.type_check("", float, allow_none=True, top_p=top_p) + error.type_check("", float, allow_none=True, typical_p=typical_p) + error.type_check("", float, allow_none=True, temperature=temperature) + error.type_check("", int, allow_none=True, seed=seed) error.type_check( - "", float, allow_none=True, repetition_penalty=repetition_penalty + "", float, allow_none=True, repetition_penalty=repetition_penalty ) - error.type_check("", float, allow_none=True, max_time=max_time) + error.type_check("", float, allow_none=True, max_time=max_time) error.type_check( - "", + "", ExponentialDecayLengthPenalty, tuple, allow_none=True, @@ -119,7 +119,7 @@ def validate_inf_params( ) error.type_check_all( - "", str, allow_none=True, stop_sequences=stop_sequences + "", str, allow_none=True, stop_sequences=stop_sequences ) From 10dac557255854a64665fcb5295e279d9dda2d68 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:35:46 -0700 Subject: [PATCH 12/12] =?UTF-8?q?=F0=9F=92=A1=20Docstring=20edit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- caikit_nlp/modules/text_generation/peft_tgis_remote.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 4203af37..896389a4 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -183,12 +183,10 @@ def run( ] = None, stop_sequences: Optional[str] = None, ) -> GeneratedTextResult: - """Run inference against the model running in TGIS. Currently we leverage greedy decoding - and apply the same verbalizer used for training the local model prior to sending the - request to TGIS. + """Run inference against the model running in TGIS. 
Args: - {} + {} Returns: GeneratedTextResult Generated text result produced by TGIS.
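Taken together, the series leaves TextGenerationTGIS (and the PEFT remote module) with the parameter surface exercised in the tests. A minimal end-to-end sketch, assuming the test fixtures StubTGISBackend, CAUSAL_LM_MODEL, and SAMPLE_TEXT stand in for a live TGIS server, with illustrative sampling values:

    from caikit_nlp.data_model import ExponentialDecayLengthPenalty
    from caikit_nlp.modules.text_generation import TextGenerationTGIS
    from tests.fixtures import CAUSAL_LM_MODEL, SAMPLE_TEXT, StubTGISBackend

    model = TextGenerationTGIS.bootstrap(
        CAUSAL_LM_MODEL, load_backend=StubTGISBackend()
    )
    result = model.run(
        SAMPLE_TEXT,
        preserve_input_text=True,
        max_new_tokens=20,
        min_new_tokens=4,
        decoding_method="SAMPLING",
        top_k=2,
        top_p=0.23,
        typical_p=0.23,
        temperature=0.75,
        seed=11,
        repetition_penalty=1.6,
        max_time=10.5,  # seconds; converted to time_limit_millis in get_params
        exponential_decay_length_penalty=ExponentialDecayLengthPenalty(
            start_index=2, decay_factor=0.95
        ),
        stop_sequences=["This is a test"],
    )
    print(result.generated_text, result.input_token_count)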