Skip to content

Commit

Permalink
🔀 Merge branch 'main' into seed_fix
Browse files Browse the repository at this point in the history
Signed-off-by: Thara Palanivel <[email protected]>
  • Loading branch information
tharapalanivel committed Sep 5, 2023
2 parents 5a55a72 + d0c18d7 commit 7b0306e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""
Run the full text generation model.
Expand Down Expand Up @@ -247,7 +247,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run the text generation model with output streaming
Expand Down
4 changes: 2 additions & 2 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS.
Expand Down Expand Up @@ -240,7 +240,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing against the model running in TGIS
Expand Down
6 changes: 3 additions & 3 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# Standard
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union
import os

# Third Party
Expand Down Expand Up @@ -220,7 +220,7 @@ def run(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS.
Expand Down Expand Up @@ -271,7 +271,7 @@ def run_stream_out(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing for text generation module.
Expand Down
8 changes: 4 additions & 4 deletions caikit_nlp/toolkit/text_generation/model_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Utility functions used for executing run function for text_generation"""

# Standard
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

# Third Party
from transformers import StoppingCriteria, TextStreamer
Expand Down Expand Up @@ -97,7 +97,7 @@
consist of: (start_index, decay_factor) where start_index
indicates where penalty starts and decay_factor represents the factor
of exponential decay
stop_sequences: List[str]:
stop_sequences: List[str]
List of strings to be used as stopping criteria
"""

Expand Down Expand Up @@ -148,7 +148,7 @@ def generate_text_func(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -256,7 +256,7 @@ def generate_text_func_stream(
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs,
):
"""
Expand Down

0 comments on commit 7b0306e

Please sign in to comment.