Skip to content

Commit

Permalink
Merge pull request #171 from tharapalanivel/seed_fix
Browse files — Browse the repository at this point in the history
Fix seed type mismatch
  • Loading branch information
gkumbhat authored Sep 5, 2023
2 parents d0c18d7 + 7b0306e commit eda7af7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
5 changes: 3 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
import numpy as np
import torch

# First Party
Expand Down Expand Up @@ -180,7 +181,7 @@ def run(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down Expand Up @@ -240,7 +241,7 @@ def run_stream_out(
top_p: Optional[float] = 0.0,
typical_p: Optional[float] = 0.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 0.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down
7 changes: 5 additions & 2 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,9 @@
from typing import Iterable, List, Optional, Tuple, Union
import os

# Third Party
import numpy as np

# First Party
from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, modules
from caikit.core.module_backends import BackendBase, backend_types
Expand Down Expand Up @@ -175,7 +178,7 @@ def run(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down Expand Up @@ -231,7 +234,7 @@ def run_stream_out(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down
7 changes: 5 additions & 2 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Iterable, List, Optional, Tuple, Union
import os

# Third Party
import numpy as np

# First Party
from caikit.core.module_backends import BackendBase, backend_types
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
Expand Down Expand Up @@ -211,7 +214,7 @@ def run(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down Expand Up @@ -262,7 +265,7 @@ def run_stream_out(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down
7 changes: 4 additions & 3 deletions caikit_nlp/toolkit/text_generation/model_run_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,7 @@

# Third Party
from transformers import StoppingCriteria, TextStreamer
import numpy as np
import torch

# First Party
Expand Down Expand Up @@ -79,7 +80,7 @@
The value used to modulate the next token probabilities.
Only applicable when decoding_method is SAMPLING.
Default: 1.0 - means disabled - equivalent to 1.0
seed: int
seed: numpy.uint64
Random seed to control sampling. Only applicable when decoding_method
is SAMPLING. Default: None
repetition_penalty: float
Expand Down Expand Up @@ -141,7 +142,7 @@ def generate_text_func(
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down Expand Up @@ -249,7 +250,7 @@ def generate_text_func_stream(
top_p: Optional[float] = 0.0,
typical_p: Optional[float] = 0.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
seed: Optional[np.uint64] = None,
repetition_penalty: Optional[float] = 0.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Expand Down

0 comments on commit eda7af7

Please sign in to comment.