diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index c2acec11..b637e9f6 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -38,7 +38,6 @@ AutoConfig, AutoModelForCausalLM, DataCollatorForLanguageModeling, - TextStreamer, default_data_collator, ) from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -59,7 +58,12 @@ import alog # Local -from ...data_model import GenerationTrainRecord, PromptOutputModelType, TuningConfig +from ...data_model import ( + ExponentialDecayLengthPenalty, + GenerationTrainRecord, + PromptOutputModelType, + TuningConfig, +) from ...resources.pretrained_model import ( HFAutoCausalLM, HFAutoSeq2SeqLM, @@ -68,6 +72,11 @@ from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype from ...toolkit.task_specific_utils import convert_to_generation_record +from ...toolkit.text_generation.model_run_utils import ( + GENERATE_FUNCTION_ARGS, + generate_text_func, + generate_text_func_stream, +) from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer log = alog.use_channel("PEFT_PROMPT") @@ -94,13 +103,6 @@ class TuningType(str, Enum): # LORA = "LORA" -class Streamer(TextStreamer): - # The default TextStreamer currently prints to stdout - # so we override that here - def on_finalized_text(self, text: str, stream_end: bool = False): - pass - - # TODO: try to refactor this into a smaller module # pylint: disable=too-many-lines,too-many-instance-attributes @module( @@ -170,53 +172,55 @@ def __del__(self): def run( self, text: str, - device: Optional[Union[str, int]] = None, - max_new_tokens=20, - min_new_tokens=0, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 1.0, + typical_p: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + seed: Optional[int] = None, + repetition_penalty: Optional[float] = 1.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[ + Union[Tuple[int, float], ExponentialDecayLengthPenalty] + ] = None, + stop_sequences: Optional[str] = None, ) -> GeneratedTextResult: - """Run the full text generation model. - - Args: - text: str - Input string to be used to the generation model. - device: Optional[Union[str, int]] - Deprecated. By default, we use the detected device. - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 20 - min_new_tokens: int - The minimum numbers of tokens to generate. - Default: 0 - means no minimum - - Returns: - GeneratedTextResult - Generated text result produced by PEFT / Transformers. """ - if device is not None: - log.warning( - "Specifying device is deprecated and ignored, please update your calling argument" - ) - device = self._DETECT_DEVICE - # Apply the verbalizer to our text string + Run the full text generation model. + Args: + {} + Returns: + GeneratedTextResult + Generated text result produced by PEFT / Transformers. 
+ """.format( + GENERATE_FUNCTION_ARGS + ) + verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) - # Apply the tokenizer to the sample text & move to correct device - tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt") - - device = PeftPromptTuning._get_device(device) - inputs = {k: v.to(device) for k, v in tok_tensors.items()} - with torch.no_grad(): - # Run tokenized tensors through the rest of the PEFT model - outputs = self.model.generate( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - eos_token_id=self.eos_token_id, - ) - gen_text = self.tokenizer.batch_decode( - outputs.detach().cpu().numpy(), skip_special_tokens=True - ) - return GeneratedTextResult(generated_text=gen_text[0]) + + return generate_text_func( + self.model, + self.tokenizer, + self.PRODUCER_ID, + self.tokenizer.eos_token, + verbalized_text, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + seed=seed, + repetition_penalty=repetition_penalty, + max_time=max_time, + exponential_decay_length_penalty=exponential_decay_length_penalty, + stop_sequences=stop_sequences, + ) # NOTE: We need to disable wip decorator here otherwise we get issues in # proto generation for streaming. We are keeping it commented out for now, @@ -226,7 +230,23 @@ def run( # ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( - self, text: str, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + seed: Optional[int] = None, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[ + Union[Tuple[int, float], ExponentialDecayLengthPenalty] + ] = None, + stop_sequences: Optional[str] = None, ) -> Iterable[GeneratedTextStreamResult]: """Run the text generation model with output streaming @@ -236,40 +256,37 @@ def run_stream_out( Ref. https://huggingface.co/docs/transformers/v4.30.0/generation_strategies#streaming Args: - text: str - Input string to be used to the generation model. - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 20 - min_new_tokens: int - The minimum numbers of tokens to generate. 
- Default: 0 - means no minimum + {} Returns: Iterable[GeneratedTextStreamResult] - """ + """.format( + GENERATE_FUNCTION_ARGS + ) + # Apply the verbalizer to our text string verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) - # Apply the tokenizer to the sample text & move to correct device - tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt") - inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()} - - streamer = Streamer(self.tokenizer) - with torch.no_grad(): - # Run tokenized tensors through the rest of the PEFT model - stream_outputs = self.model.generate( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - eos_token_id=self.eos_token_id, - streamer=streamer, - ) - for stream_part in stream_outputs: - gen_text = self.tokenizer.batch_decode( - stream_part.detach().cpu().numpy(), skip_special_tokens=True - ) - yield GeneratedTextStreamResult(generated_text=gen_text) + + return generate_text_func_stream( + self.model, + self.tokenizer, + self.PRODUCER_ID, + self.tokenizer.eos_token, + verbalized_text, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + seed=seed, + repetition_penalty=repetition_penalty, + max_time=max_time, + exponential_decay_length_penalty=exponential_decay_length_penalty, + stop_sequences=stop_sequences, + ) @classmethod def train( diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 3898f16d..370aa2e5 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -40,6 +40,10 @@ ) from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype +from ...toolkit.text_generation.model_run_utils import ( + GENERATE_FUNCTION_ARGS, + generate_text_func, +) log = alog.use_channel("TXT_GEN") error = error_handler.get(log) @@ -380,106 +384,51 @@ def save(self, model_path): def run( self, text: str, - repetition_penalty: float = 2.5, - length_penalty: float = 1.0, - early_stopping: bool = True, - num_beams: int = 1, - max_new_tokens: int = 20, - min_new_tokens: int = 0, - truncate_input_tokens: int = 0, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, **kwargs, - ) -> "GeneratedTextResult": - """Run inference against the model running in TGIS. + ) -> GeneratedTextResult: - Args: - text: str - Source string to be encoded for generation. - repetition_penalty: float - The parameter for repetition penalty. 1.0 means no penalty. - Default: 2.5 - length_penalty: float - Exponential penalty to the length that is used with beam-based generation. - It is applied as an exponent to the sequence length, \ - which is used to divide the score of the sequence. - Since the score is the log likelihood of the sequence (i.e. 
negative), \ - length_penalty > 0.0 promotes longer sequences, \ - while length_penalty < 0.0 encourages shorter sequences. - Default: 1.0. - early_stopping: bool - Controls the stopping condition for beam-based methods, like beam-search. - It accepts the following values: - True, where the generation stops as soon as there are num_beams complete candidates; - False, where an heuristic is applied and the generation stops when \ - is it very unlikely to find better candidates; - "never", where the beam search procedure only stops \ - when there cannot be better candidates (canonical beam search algorithm). - num_beams: int - Number of beams for beam search. 1 means no beam search. - Default: 1 - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 20 - min_new_tokens: int - The minimum numbers of tokens to generate. - Default: 0 - means no minimum - truncate_input_tokens: int - Truncate inputs to provided number of tokens. This can be - use to avoid failing due to input being longer than - configured limits. - Default: 0 - means don't truncate, thus throw error. - kwargs: - Any other parameters to pass to generate as specified in GenerationConfig. - https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/text_generation#transformers.GenerationConfig - Returns: - GeneratedTextResult - Generated text result produced by the model. """ + Run the full text generation model. + Args: + {} + Returns: + GeneratedTextResult + Generated text result produced by the model. + """.format( + GENERATE_FUNCTION_ARGS + ) - # NOTE: below is to match TGIS API, where 0 identifies as no truncation - if truncate_input_tokens == 0: - # NOTE: below will make model throw error in case inputs are longer - # than allowed length - truncation = False - - else: - truncation = True + # TODO: Beam search currently not supported - inputs = self.model.tokenizer( + return generate_text_func( + # Pass HF model + self.model.model, + self.model.tokenizer, + self.PRODUCER_ID, + self._eos_token, text, - truncation=truncation, - max_length=truncate_input_tokens, - return_tensors="pt", - ) - generate_ids = self.model.model.generate( - input_ids=inputs["input_ids"], - num_beams=num_beams, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, + decoding_method=decoding_method, + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, repetition_penalty=repetition_penalty, - length_penalty=length_penalty, - early_stopping=early_stopping, - use_cache=True, + max_time=max_time, **kwargs, ) - token_count = generate_ids.size(1) - 1 - preds = [ - self.model.tokenizer.decode( - g, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - for g in generate_ids - ] - if generate_ids[0][-1].item() == self._eos_token: - finish_reason = "EOS_TOKEN" - elif generate_ids.size(1) - 1 == max_new_tokens: - finish_reason = "MAX_TOKENS" - else: - finish_reason = "OTHER" - return GeneratedTextResult( - generated_tokens=token_count, - generated_text=preds[0], - finish_reason=finish_reason, - producer_id=self.PRODUCER_ID, - ) ################################## Private Functions ###################################### diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py new file mode 100644 index 00000000..ac33130b --- /dev/null +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -0,0 +1,435 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, 
Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions used for executing the run function for text_generation"""
+
+# Standard
+from typing import Optional, Tuple, Union
+
+# Third Party
+from transformers import StoppingCriteria, TextStreamer
+import torch
+
+# First Party
+from caikit.core.data_model.producer import ProducerId
+from caikit.core.toolkit.errors import error_handler
+from caikit.interfaces.nlp.data_model import (
+    GeneratedTextResult,
+    GeneratedTextStreamResult,
+    TokenStreamDetails,
+)
+import alog
+
+# Local
+from ...data_model import ExponentialDecayLengthPenalty
+
+log = alog.use_channel("RUN_UTILS")
+error = error_handler.get(log)
+
+VALID_DECODING_METHODS = ["GREEDY", "SAMPLING"]
+
+GENERATE_FUNCTION_ARGS = """
+    text: str
+        Input string to be used by the generation model.
+    max_new_tokens: int
+        The maximum number of tokens to generate.
+        Default: 20
+    min_new_tokens: int
+        The minimum number of tokens to generate.
+        Default: 0 - means no minimum
+    truncate_input_tokens: int
+        Truncate inputs to the provided number of tokens. This can be
+        used to avoid failing due to input being longer than
+        configured limits.
+        Default: 0 - means don't truncate, thus throw an error.
+    decoding_method: str
+        Strategy used to choose candidate tokens during decoding.
+        Options: "GREEDY" (default), "SAMPLING"
+    top_k: int
+        The number of highest probability vocabulary tokens to keep for
+        top-k-filtering. Only applicable when decoding_method is SAMPLING.
+        Default: 0 - means disabled
+    top_p: float
+        If set to float < 1, only the smallest set of most probable tokens
+        with probabilities that add up to top_p or higher are kept for
+        generation. Only applicable when decoding_method is SAMPLING.
+        Default: 1.0 - means disabled; 0.0 is equivalent to 1.0
+    typical_p: float
+        Local typicality measures how similar the conditional probability of
+        predicting a target token next is to the expected conditional
+        probability of predicting a random token next, given the partial text
+        already generated. If set to float < 1, the smallest set of the most
+        locally typical tokens with probabilities that add up to typical_p
+        or higher are kept for generation. Only applicable when decoding_method
+        is SAMPLING.
+        Default: 1.0 - means disabled; 0.0 is equivalent to 1.0
+    temperature: float
+        The value used to modulate the next token probabilities.
+        Only applicable when decoding_method is SAMPLING.
+        Default: 1.0 - means disabled
+    seed: int
+        Random seed to control sampling. Only applicable when decoding_method
+        is SAMPLING. Default: None
+    repetition_penalty: float
+        The more a token is used within generation, the more it is penalized,
+        making it less likely to be picked in successive generation passes.
+        Default: 1.0 - means no penalty; 0.0 is equivalent to 1.0
+    max_time: float
+        Maximum amount of time in seconds that the query should take.
+        NOTE: this does not include network overhead.
+            Range: 0-120.0
+        exponential_decay_length_penalty: Tuple(int, float)
+            This tuple adds an exponentially increasing length penalty after
+            a certain number of tokens have been generated. The tuple shall
+            consist of: (start_index, decay_factor), where start_index
+            indicates where the penalty starts and decay_factor represents
+            the factor of exponential decay.
+        stop_sequences: List[str]
+            List of strings to be used as stopping criteria
+"""
+
+
+class Streamer(TextStreamer):
+    # The default TextStreamer currently prints to stdout
+    # so we override that here
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        pass
+
+
+class SequenceStoppingCriteria(StoppingCriteria):
+    def __init__(self, target_sequence_ids):
+        self.target_sequence_ids = target_sequence_ids
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Check if the target sequence appears in the generated text
+        for seq_id in self.target_sequence_ids:
+            if seq_id in input_ids:
+                return True  # Stop generation
+
+        return False  # Continue generation
+
+    def __len__(self):
+        return 1
+
+    def __iter__(self):
+        yield self
+
+
+def generate_text_func(
+    model,
+    tokenizer,
+    producer_id: ProducerId,
+    eos_token: str,
+    text: str,
+    max_new_tokens: Optional[int] = 20,
+    min_new_tokens: Optional[int] = 0,
+    truncate_input_tokens: Optional[int] = 0,
+    decoding_method: Optional[str] = "GREEDY",
+    top_k: Optional[int] = 0,
+    top_p: Optional[float] = 1.0,
+    typical_p: Optional[float] = 1.0,
+    temperature: Optional[float] = 1.0,
+    seed: Optional[int] = None,
+    repetition_penalty: Optional[float] = 1.0,
+    max_time: Optional[float] = None,
+    exponential_decay_length_penalty: Optional[
+        Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+    ] = None,
+    stop_sequences: Optional[str] = None,
+    **kwargs,
+):
+    """
+    Args:
+        model: PeftModel or transformers.AutoModel
+            Peft model or Transformers model
+        tokenizer: AutoTokenizer
+            Tokenizer to be used with the model
+        producer_id: ProducerId
+            Caikit producer id associated with the module
+        eos_token: str
+            End of sequence token to be used with generation
+        {}
+    Returns:
+        GeneratedTextResult
+    """.format(
+        GENERATE_FUNCTION_ARGS
+    )
+
+    error.type_check("", str, eos_token=eos_token)
+    error.type_check("", str, text=text)
+
+    error.type_check(
+        "",
+        int,
+        allow_none=True,
+        truncate_input_tokens=truncate_input_tokens,
+    )
+
+    # NOTE: below is to match TGIS API, where 0 identifies as no truncation
+    truncation = truncate_input_tokens != 0
+
+    tok_tensors = tokenizer(
+        text,
+        truncation=truncation,
+        max_length=truncate_input_tokens,
+        return_tensors="pt",
+    )
+    inputs = {k: v.to(model.device) for k, v in tok_tensors.items()}
+
+    input_token_count = len(tok_tensors)
+
+    gen_optional_params = __process_gen_args(
+        tokenizer,
+        max_new_tokens,
+        min_new_tokens,
+        decoding_method,
+        top_k,
+        top_p,
+        typical_p,
+        temperature,
+        seed,
+        repetition_penalty,
+        max_time,
+        exponential_decay_length_penalty,
+        stop_sequences,
+    )
+
+    with torch.no_grad():
+        generate_ids = model.generate(
+            input_ids=inputs["input_ids"],
+            **gen_optional_params,
+            **kwargs,
+        )
+
+    token_count = generate_ids.size(1) - 1
+
+    preds = [
+        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        for g in generate_ids
+    ]
+
+    if generate_ids[0][-1].item() == eos_token:
+        finish_reason = "EOS_TOKEN"
+    elif generate_ids.size(1) - 1 == max_new_tokens:
+        finish_reason = "MAX_TOKENS"
+    else:
+        finish_reason = "OTHER"
+    return GeneratedTextResult(
+        generated_tokens=token_count,
+
generated_text=preds[0], + finish_reason=finish_reason, + producer_id=producer_id, + input_token_count=input_token_count, + ) + + +def generate_text_func_stream( + model, + tokenizer, + producer_id: ProducerId, + eos_token: str, + text: str, + max_new_tokens: Optional[int] = 20, + min_new_tokens: Optional[int] = 0, + truncate_input_tokens: Optional[int] = 0, + decoding_method: Optional[str] = "GREEDY", + top_k: Optional[int] = 0, + top_p: Optional[float] = 0.0, + typical_p: Optional[float] = 0.0, + temperature: Optional[float] = 1.0, + seed: Optional[int] = None, + repetition_penalty: Optional[float] = 0.0, + max_time: Optional[float] = None, + exponential_decay_length_penalty: Optional[ + Union[Tuple[int, float], ExponentialDecayLengthPenalty] + ] = None, + stop_sequences: Optional[str] = None, + **kwargs, +): + """ + Args: + model: PeftModel or transformers.AutoModel + Peft model or Transformers model + tokenizer: AutoTokenizer + Tokenizer to be used with the model + producer_id: ProducerId + Caikit producer id associated with the module + eos_token: str + End of sequence token to be used with generation + {} + Returns: + GeneratedTextResult + """.format( + GENERATE_FUNCTION_ARGS + ) + error.type_check("", str, eos_token=eos_token) + error.type_check("", str, text=text) + + error.type_check( + "", + int, + allow_none=True, + truncate_input_tokens=truncate_input_tokens, + ) + + # NOTE: below is to match TGIS API, where 0 identifies as no truncation + truncation = truncate_input_tokens != 0 + + tok_tensors = tokenizer( + text, + truncation=truncation, + max_length=truncate_input_tokens, + return_tensors="pt", + ) + inputs = {k: v.to(model.device) for k, v in tok_tensors.items()} + + input_token_count = len(tok_tensors) + + streamer = Streamer(tokenizer) + + gen_optional_params = __process_gen_args( + tokenizer, + max_new_tokens, + min_new_tokens, + decoding_method, + top_k, + top_p, + typical_p, + temperature, + seed, + repetition_penalty, + max_time, + exponential_decay_length_penalty, + stop_sequences, + ) + + with torch.no_grad(): + # Run tokenized tensors through the rest of the PEFT model + stream_outputs = model.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + streamer=streamer, + **gen_optional_params, + **kwargs, + ) + details = TokenStreamDetails( + input_token_count=input_token_count, + ) + for stream_part in stream_outputs: + gen_text = tokenizer.batch_decode( + stream_part.detach().cpu().numpy(), skip_special_tokens=True + ) + yield GeneratedTextStreamResult( + generated_text=gen_text, details=details, producer_id=producer_id + ) + + +def __process_gen_args( + tokenizer, + max_new_tokens, + min_new_tokens, + decoding_method, + top_k, + top_p, + typical_p, + temperature, + seed, + repetition_penalty, + max_time, + exponential_decay_length_penalty, + stop_sequences, +): + """Utility function to preprocess model generate arguments""" + error.type_check( + "", int, allow_none=True, max_new_tokens=max_new_tokens + ) + error.type_check( + "", int, allow_none=True, min_new_tokens=min_new_tokens + ) + error.type_check("", int, allow_none=True, top_k=top_k) + error.type_check("", float, allow_none=True, top_p=top_p) + error.type_check( + "", + float, + allow_none=True, + typical_p=typical_p, + temperature=temperature, + ) + error.type_check( + "", + float, + allow_none=True, + repetition_penalty=repetition_penalty, + max_time=max_time, + ) + error.type_check_all( + "", str, allow_none=True, stop_sequences=stop_sequences + ) + error.type_check("", 
int, allow_none=True, seed=seed) + + error.value_check( + "", + max_new_tokens >= min_new_tokens, + "Max new tokens needs to be bigger than min new tokens", + ) + + if isinstance(exponential_decay_length_penalty, ExponentialDecayLengthPenalty): + exponential_decay_length_penalty = ( + exponential_decay_length_penalty.start_index, + exponential_decay_length_penalty.decay_factor, + ) + + error.type_check( + "", + tuple, + allow_none=True, + exponential_decay_length_penalty=exponential_decay_length_penalty, + ) + + error.value_check( + "", + decoding_method in VALID_DECODING_METHODS, + f"Decoding method [{decoding_method}] not in valid decoding methods: " + f"[{VALID_DECODING_METHODS}]", + ) + + if repetition_penalty == 0.0: + repetition_penalty = 1.0 + + gen_optional_params = { + "max_new_tokens": max_new_tokens, + "min_new_tokens": min_new_tokens, + "repetition_penalty": repetition_penalty, + "use_cache": True, + "max_time": max_time, + "exponential_decay_length_penalty": exponential_decay_length_penalty, + } + + # TODO: Make decoding parameters enums + if decoding_method == "SAMPLING": + gen_optional_params["do_sample"] = True + gen_optional_params["top_k"] = top_k + gen_optional_params["top_p"] = top_p + gen_optional_params["typical_p"] = typical_p + gen_optional_params["temperature"] = temperature + gen_optional_params["seed"] = seed + + if stop_sequences and len(stop_sequences) > 0: + # Tokenize sequences + stop_sequence_ids = tokenizer.encode(stop_sequences) + stopping_criteria = SequenceStoppingCriteria(stop_sequence_ids) + gen_optional_params["stopping_criteria"] = stopping_criteria + + return gen_optional_params diff --git a/pyproject.toml b/pyproject.toml index a6610177..e3249f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.16.0,<0.18.0", + "caikit[runtime-grpc,runtime-http]>=0.18.0,<0.20.0", "caikit-tgis-backend>=0.1.16,<0.2.0", # TODO: loosen dependencies "accelerate>=0.21.0", diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 5ed1c9fe..2d16c4b8 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -24,6 +24,7 @@ import caikit # Local +from caikit_nlp.data_model import ExponentialDecayLengthPenalty from caikit_nlp.modules.text_generation import PeftPromptTuning from caikit_nlp.modules.text_generation.peft_prompt_tuning import TuningType from tests.fixtures import ( @@ -75,6 +76,7 @@ def test_run_stream_out_model(causal_lm_dummy_model): pred_stream = causal_lm_dummy_model.run_stream_out("This text doesn't matter") assert isinstance(pred_stream, Iterable) for pred in pred_stream: + print(pred) assert isinstance(pred, GeneratedTextStreamResult) @@ -82,7 +84,9 @@ def test_verbalizer_rendering(causal_lm_dummy_model): """Ensure that our model renders its verbalizer text correctly before calling tokenizer.""" # Mock the tokenizer; we want to make sure its inputs are rendered properly causal_lm_dummy_model.tokenizer = mock.Mock( - side_effect=RuntimeError("Tokenizer is a mock!") + side_effect=RuntimeError("Tokenizer is a mock!"), + # Set eos token property to be attribute of tokenizer + eos_token="", ) input_text = "This text doesn't matter" causal_lm_dummy_model.verbalizer = " | {{input}} |" @@ -297,3 +301,60 @@ def 
test_model_can_only_have_one_or_two_transformer_modules(seq2seq_lm_dummy_mod
         TuningType.PROMPT_TUNING,
         seq2seq_lm_dummy_model.output_model_types,
     )
+
+
+######################## Test run with optional params #####################
+
+
+def test_run_repetition_penalty_0_works(causal_lm_dummy_model):
+    """Ensure repetition_penalty works with 0.0 as input"""
+    pred = causal_lm_dummy_model.run("This text doesn't matter", repetition_penalty=0.0)
+    assert isinstance(pred, GeneratedTextResult)
+
+
+def test_run_truncate_tokens_0(causal_lm_dummy_model):
+    """Ensure the run function accepts 0 as the truncation value
+    and successfully turns off truncation"""
+    pred = causal_lm_dummy_model.run(
+        "This text doesn't matter", truncate_input_tokens=0
+    )
+    assert isinstance(pred, GeneratedTextResult)
+
+
+def test_run_sampling_param_ignored_greedy_decoding(causal_lm_dummy_model):
+    """Ensure sampling parameters get ignored when the decoding method
+    is set to GREEDY
+    """
+    pred = causal_lm_dummy_model.run(
+        "This text doesn't matter",
+        decoding_method="GREEDY",
+        top_k=2,
+        top_p=0.23,
+        typical_p=0.23,
+        temperature=0.77,
+    )
+    assert isinstance(pred, GeneratedTextResult)
+
+
+def test_run_with_custom_stop_criteria(causal_lm_dummy_model):
+    """Ensure custom stop sequences work with run"""
+    pred = causal_lm_dummy_model.run(
+        "This text doesn't matter",
+        decoding_method="GREEDY",
+        stop_sequences=["Foo", "bar"],
+    )
+    assert isinstance(pred, GeneratedTextResult)
+
+
+def test_run_exponential_decay_len_penalty_object(causal_lm_dummy_model):
+    """Ensure exponential decay length penalty works with the data model
+    object
+    """
+    penalty = ExponentialDecayLengthPenalty(1, 0.2)
+    pred = causal_lm_dummy_model.run(
+        "This text doesn't matter",
+        decoding_method="GREEDY",
+        stop_sequences=["Foo", "bar"],
+        exponential_decay_length_penalty=penalty,
+    )
+    assert isinstance(pred, GeneratedTextResult)
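
For context on how the consolidated helper introduced in model_run_utils.py is meant to be exercised outside of the module classes, below is a minimal usage sketch, not part of the diff. It assumes the new toolkit subpackage is importable and that a Hugging Face causal LM is available locally; the "gpt2" checkpoint and the ProducerId values are illustrative placeholders, not part of this change. Sampling options follow the behavior documented in GENERATE_FUNCTION_ARGS: they only take effect when decoding_method is "SAMPLING".

# Illustrative sketch only. "gpt2" and the ProducerId values are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from caikit.core.data_model.producer import ProducerId
from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Sampling parameters are honored here because decoding_method="SAMPLING";
# under the default "GREEDY" they are ignored, as the new tests above verify.
result = generate_text_func(
    model,
    tokenizer,
    ProducerId("example-producer", "0.0.1"),
    tokenizer.eos_token,
    "What is the boiling point of Nitrogen?",
    max_new_tokens=32,
    decoding_method="SAMPLING",
    top_k=5,
    top_p=0.9,
    temperature=0.7,
)
print(result.generated_text)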