diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index c1b4ca53..a40c3aca 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -42,6 +42,7 @@ ) from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.optimization import get_linear_schedule_with_warmup +import numpy as np import torch # First Party @@ -180,7 +181,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -240,7 +241,7 @@ def run_stream_out( top_p: Optional[float] = 0.0, typical_p: Optional[float] = 0.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 0.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index a62a1c7c..6b712104 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -18,6 +18,9 @@ from typing import Iterable, List, Optional, Tuple, Union import os +# Third Party +import numpy as np + # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, modules from caikit.core.module_backends import BackendBase, backend_types @@ -175,7 +178,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -231,7 +234,7 @@ def run_stream_out( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 43be59c4..e2130ed7 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -17,6 +17,9 @@ from typing import Iterable, List, Optional, Tuple, Union import os +# Third Party +import numpy as np + # First Party from caikit.core.module_backends import BackendBase, backend_types from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module @@ -211,7 +214,7 @@ def run( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -262,7 +265,7 @@ def run_stream_out( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index 43b3b195..45699166 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -19,6 +19,7 @@ # Third Party from transformers import StoppingCriteria, TextStreamer +import numpy as np import torch # First Party @@ -79,7 +80,7 @@ The value used to modulate the next token probabilities. Only applicable when decoding_method is SAMPLING. Default: 1.0 - means disabled - equivalent to 1.0 - seed: int + seed: numpy.uint64 Random seed to control sampling. Only applicable when decoding_method is SAMPLING. Default: None repetition_penalty: float @@ -141,7 +142,7 @@ def generate_text_func( top_p: Optional[float] = 1.0, typical_p: Optional[float] = 1.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 1.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[ @@ -249,7 +250,7 @@ def generate_text_func_stream( top_p: Optional[float] = 0.0, typical_p: Optional[float] = 0.0, temperature: Optional[float] = 1.0, - seed: Optional[int] = None, + seed: Optional[np.uint64] = None, repetition_penalty: Optional[float] = 0.0, max_time: Optional[float] = None, exponential_decay_length_penalty: Optional[