Skip to content

Commit

Permalink
Merge pull request #168 from tharapalanivel/stop_sequences_fix
Browse files Browse the repository at this point in the history
🐛 stop_sequences as list of str
  • Loading branch information
gkumbhat authored Sep 5, 2023
2 parents 841f446 + 5f93677 commit d0c18d7
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""
Run the full text generation model.
Expand Down Expand Up @@ -246,7 +246,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run the text generation model with output streaming
Expand Down
4 changes: 2 additions & 2 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS.
Expand Down Expand Up @@ -237,7 +237,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing against the model running in TGIS
Expand Down
6 changes: 3 additions & 3 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# Standard
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union
import os

# First Party
Expand Down Expand Up @@ -217,7 +217,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS.
Expand Down Expand Up @@ -268,7 +268,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing for text generation module.
Expand Down
8 changes: 4 additions & 4 deletions caikit_nlp/toolkit/text_generation/model_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Utility functions used for executing run function for text_generation"""

# Standard
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

# Third Party
from transformers import StoppingCriteria, TextStreamer
Expand Down Expand Up @@ -96,7 +96,7 @@
consist of: (start_index, decay_factor) where start_index
indicates where penalty starts and decay_factor represents the factor
of exponential decay
stop_sequences: List[str]:
stop_sequences: List[str]
List of strings to be used as stopping criteria
"""

Expand Down Expand Up @@ -147,7 +147,7 @@ def generate_text_func(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -255,7 +255,7 @@ def generate_text_func_stream(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs,
):
"""
Expand Down

0 comments on commit d0c18d7

Please sign in to comment.