From 5a55a724d13b356d8ee2ec183bafeffcdebf9e06 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:41:49 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20seed=20type=20mismatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- caikit_nlp/modules/text_generation/peft_prompt_tuning.py | 5 +++-- caikit_nlp/modules/text_generation/peft_tgis_remote.py | 7 +++++-- caikit_nlp/modules/text_generation/text_generation_tgis.py | 7 +++++-- caikit_nlp/toolkit/text_generation/model_run_utils.py | 7 ++++--- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index b637e9f6..aa89adb7 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -42,6 +42,7 @@ ) from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.optimization import get_linear_schedule_with_warmup +import numpy as np import torch # First Party @@ -180,7 +181,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -240,7 +241,7 @@ def run_stream_out( top_p: Optional[float] = 0.0, typical_p: Optional[float] = 0.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 0.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 896389a4..ef3eb6fc 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -18,6 +18,9 @@ from typing import Iterable, List, Optional, Tuple, Union import os +# Third Party +import numpy as np + # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, modules from caikit.core.module_backends import BackendBase, backend_types @@ -175,7 +178,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -231,7 +234,7 @@ def run_stream_out( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 1cb54a76..61a5331b 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -17,6 +17,9 @@ from typing import Iterable, Optional, Tuple, Union import os +# Third Party +import numpy as np + # First Party from caikit.core.module_backends import BackendBase, backend_types from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module @@ -211,7 +214,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -262,7 +265,7 @@ def run_stream_out( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index 49cee8d7..b8454249 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -19,6 +19,7 @@ # Third Party from transformers import StoppingCriteria, TextStreamer +import numpy as np import torch # First Party @@ -79,7 +80,7 @@ The value used to modulate the next token probabilities. Only applicable when decoding_method is SAMPLING. Default: 1.0 - means disabled - equivalent to 1.0 - seed: int + seed: numpy.uint64 Random seed to control sampling. Only applicable when decoding_method is SAMPLING. Default: None repetition_penalty: float @@ -141,7 +142,7 @@ def generate_text_func( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -249,7 +250,7 @@ def generate_text_func_stream( top_p: Optional[float] = 0.0, typical_p: Optional[float] = 0.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 0.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[