Merge pull request #155 from tharapalanivel/more_inference_params
Add more text generation inference params to TGIS modules
tharapalanivel authored Aug 31, 2023
2 parents 4db0c26 + 10dac55 commit 0d8a323
Showing 8 changed files with 685 additions and 365 deletions.
138 changes: 80 additions & 58 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -15,7 +15,7 @@
 prompt vectors in TGIS generation requests.
 """
 # Standard
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Tuple, Union
 import os
 
 # First Party
@@ -31,7 +31,11 @@
 import alog
 
 # Local
-from ...toolkit.tgis_utils import TGISGenerationClient
+from ...data_model import ExponentialDecayLengthPenalty
+from ...toolkit.text_generation.tgis_utils import (
+    GENERATE_FUNCTION_TGIS_ARGS,
+    TGISGenerationClient,
+)
 from ...toolkit.verbalizer_utils import render_verbalizer
 from . import PeftPromptTuning
 
@@ -161,82 +165,90 @@ def save(self, model_path: str):
     @TextGenerationTask.taskmethod()
     def run(
         self,
-        text,
-        preserve_input_text=False,
-        max_new_tokens=20,
-        min_new_tokens=0,
-        truncate_input_tokens=0,
+        text: str,
+        preserve_input_text: bool = False,
+        max_new_tokens: Optional[int] = 20,
+        min_new_tokens: Optional[int] = 0,
+        truncate_input_tokens: Optional[int] = 0,
+        decoding_method: Optional[str] = "GREEDY",
+        top_k: Optional[int] = 0,
+        top_p: Optional[float] = 1.0,
+        typical_p: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
+        repetition_penalty: Optional[float] = 1.0,
+        max_time: Optional[float] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
+        stop_sequences: Optional[str] = None,
     ) -> GeneratedTextResult:
-        """Run inference against the model running in TGIS. Currently we leverage greedy
-        decoding and apply the same verbalizer used for training the local model prior to
-        sending the request to TGIS.
+        """Run inference against the model running in TGIS.
         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: bool
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum number of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum number of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to the provided number of tokens. This can be
-                used to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, so over-length input raises an error.
+            {}
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
-        """
+        """.format(
+            GENERATE_FUNCTION_TGIS_ARGS
+        )
 
         error.value_check(
             "<NLP87360638E>",
             self.enable_backend,
             "Backend must be configured and loaded with this module before executing `run` call.",
         )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.unary_generate(
-            verbalized_text,
-            preserve_input_text,
-            max_new_tokens,
-            min_new_tokens,
-            truncate_input_tokens,
+            text=verbalized_text,
+            preserve_input_text=preserve_input_text,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            temperature=temperature,
+            seed=seed,
+            repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
+            stop_sequences=stop_sequences,
         )
 
     @TextGenerationTask.taskmethod(output_streaming=True)
     def run_stream_out(
         self,
         text: str,
-        preserve_input_text=False,
-        max_new_tokens=20,
-        min_new_tokens=0,
-        truncate_input_tokens=0,
+        preserve_input_text: bool = False,
+        max_new_tokens: Optional[int] = 20,
+        min_new_tokens: Optional[int] = 0,
+        truncate_input_tokens: Optional[int] = 0,
+        decoding_method: Optional[str] = "GREEDY",
+        top_k: Optional[int] = 0,
+        top_p: Optional[float] = 1.0,
+        typical_p: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
+        repetition_penalty: Optional[float] = 1.0,
+        max_time: Optional[float] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
+        stop_sequences: Optional[str] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         """Run output stream inferencing against the model running in TGIS.
         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: bool
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum number of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum number of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to the provided number of tokens. This can be
-                used to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, so over-length input raises an error.
+            {}
         Returns:
             Iterable[GeneratedTextStreamResult]
-        """
+        """.format(
+            GENERATE_FUNCTION_TGIS_ARGS
+        )
 
         error.value_check(
             "<NLP62995899E>",
             self.enable_backend,
@@ -245,9 +257,19 @@ def run_stream_out(
         )
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         return self.tgis_generation_client.stream_generate(
-            verbalized_text,
-            preserve_input_text,
-            max_new_tokens,
-            min_new_tokens,
-            truncate_input_tokens,
+            text=verbalized_text,
+            preserve_input_text=preserve_input_text,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            truncate_input_tokens=truncate_input_tokens,
+            decoding_method=decoding_method,
+            top_k=top_k,
+            top_p=top_p,
+            typical_p=typical_p,
+            temperature=temperature,
+            seed=seed,
+            repetition_penalty=repetition_penalty,
+            max_time=max_time,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
+            stop_sequences=stop_sequences,
        )
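Taken together, the new signature exposes TGIS decoding controls through the module API rather than hard-coding greedy decoding. As a rough illustration (not part of this diff), here is how a caller might exercise the expanded `run` on a loaded prompt-tuned model. The model path, input text, and parameter values are made up, and loading via `caikit.core.load` is an assumption about the surrounding caikit setup rather than something this commit shows.

    # Hypothetical usage sketch, not from this PR. Assumes a prompt-tuned
    # model previously saved with save(model_path) and a TGIS backend that
    # caikit is configured to reach.
    import caikit.core

    model = caikit.core.load("path/to/saved/prompt_model")

    result = model.run(
        "Review: the battery life is excellent.",
        max_new_tokens=64,
        min_new_tokens=5,
        truncate_input_tokens=512,   # truncate instead of erroring on long input
        decoding_method="SAMPLING",  # assumed counterpart to the GREEDY default
        top_k=50,
        top_p=0.9,
        temperature=0.7,
        seed=42,                     # make sampling reproducible
        repetition_penalty=1.2,
        max_time=10.0,               # cap generation time (assumed seconds)
        # assumed (start_index, decay_factor), matching the
        # Tuple[int, float] branch of the annotation above
        exponential_decay_length_penalty=(48, 1.05),
    )
    print(result.generated_text)

Everything here should be read as a sketch of intent, not a tested recipe; only the parameter names and types come from the diff itself.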
130 changes: 79 additions & 51 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,7 +14,7 @@
 
 
 # Standard
-from typing import Iterable, Optional, Union
+from typing import Iterable, Optional, Tuple, Union
 import os
 
 # First Party
@@ -30,18 +30,24 @@
 import alog
 
 # Local
+from ...data_model import ExponentialDecayLengthPenalty
 from ...resources.pretrained_model import (
     HFAutoCausalLM,
     HFAutoSeq2SeqLM,
     PretrainedModelBase,
 )
-from ...toolkit.tgis_utils import TGISGenerationClient
+from ...toolkit.text_generation.tgis_utils import (
+    GENERATE_FUNCTION_TGIS_ARGS,
+    TGISGenerationClient,
+)
 from .text_generation_local import TextGeneration
 
 log = alog.use_channel("TXT_GEN")
 error = error_handler.get(log)
 
+# pylint: disable=too-many-instance-attributes
+
 
 @module(backend_type=TGISBackend.backend_type, base_module=TextGeneration)
 class TextGenerationTGIS(ModuleBase):
     """Module to provide text generation capabilities"""
@@ -197,76 +203,98 @@ def run(
         self,
         text: str,
         preserve_input_text: bool = False,
-        max_new_tokens: int = 20,
-        min_new_tokens: int = 0,
-        truncate_input_tokens: int = 0,
+        max_new_tokens: Optional[int] = 20,
+        min_new_tokens: Optional[int] = 0,
+        truncate_input_tokens: Optional[int] = 0,
+        decoding_method: Optional[str] = "GREEDY",
+        top_k: Optional[int] = 0,
+        top_p: Optional[float] = 1.0,
+        typical_p: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
+        repetition_penalty: Optional[float] = 1.0,
+        max_time: Optional[float] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
+        stop_sequences: Optional[str] = None,
     ) -> GeneratedTextResult:
         """Run inference against the model running in TGIS.
         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: bool
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                The maximum number of tokens to generate.
-                Default: 20
-            min_new_tokens: int
-                The minimum number of tokens to generate.
-                Default: 0 - means no minimum
-            truncate_input_tokens: int
-                Truncate inputs to the provided number of tokens. This can be
-                used to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, so over-length input raises an error.
+            {}
         Returns:
             GeneratedTextResult
                 Generated text result produced by TGIS.
-        """
+        """.format(
+            GENERATE_FUNCTION_TGIS_ARGS
+        )
 
         if self._model_loaded:
             return self.tgis_generation_client.unary_generate(
-                text,
-                preserve_input_text,
-                max_new_tokens,
-                min_new_tokens,
-                truncate_input_tokens,
+                text=text,
+                preserve_input_text=preserve_input_text,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                truncate_input_tokens=truncate_input_tokens,
+                decoding_method=decoding_method,
+                top_k=top_k,
+                top_p=top_p,
+                typical_p=typical_p,
+                temperature=temperature,
+                seed=seed,
+                repetition_penalty=repetition_penalty,
+                max_time=max_time,
+                exponential_decay_length_penalty=exponential_decay_length_penalty,
+                stop_sequences=stop_sequences,
             )
 
     @TextGenerationTask.taskmethod(output_streaming=True)
     def run_stream_out(
         self,
         text: str,
-        preserve_input_text=False,
-        max_new_tokens=20,
-        min_new_tokens=0,
-        truncate_input_tokens=0,
+        preserve_input_text: bool = False,
+        max_new_tokens: Optional[int] = 20,
+        min_new_tokens: Optional[int] = 0,
+        truncate_input_tokens: Optional[int] = 0,
+        decoding_method: Optional[str] = "GREEDY",
+        top_k: Optional[int] = 0,
+        top_p: Optional[float] = 1.0,
+        typical_p: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        seed: Optional[int] = None,
+        repetition_penalty: Optional[float] = 1.0,
+        max_time: Optional[float] = None,
+        exponential_decay_length_penalty: Optional[
+            Union[Tuple[int, float], ExponentialDecayLengthPenalty]
+        ] = None,
+        stop_sequences: Optional[str] = None,
    ) -> Iterable[GeneratedTextStreamResult]:
         """Run output stream inferencing for text generation module.
         Args:
-            text: str
-                Source string to be encoded for generation.
-            preserve_input_text: bool
-                Whether or not the source string should be contained in the generated output,
-                e.g., as a prefix.
-            max_new_tokens: int
-                Maximum tokens for the model to generate
-            min_new_tokens: int
-                Minimum tokens for the model to generate
-            truncate_input_tokens: int
-                Truncate inputs to the provided number of tokens. This can be
-                used to avoid failing due to input being longer than
-                configured limits.
-                Default: 0 - means don't truncate, so over-length input raises an error.
+            {}
         Returns:
             Iterable[GeneratedTextStreamResult]
-        """
+        """.format(
+            GENERATE_FUNCTION_TGIS_ARGS
+        )
 
         if self._model_loaded:
             return self.tgis_generation_client.stream_generate(
-                text,
-                preserve_input_text,
-                max_new_tokens,
-                min_new_tokens,
-                truncate_input_tokens,
+                text=text,
+                preserve_input_text=preserve_input_text,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                truncate_input_tokens=truncate_input_tokens,
+                decoding_method=decoding_method,
+                top_k=top_k,
+                top_p=top_p,
+                typical_p=typical_p,
+                temperature=temperature,
+                seed=seed,
+                repetition_penalty=repetition_penalty,
+                max_time=max_time,
+                exponential_decay_length_penalty=exponential_decay_length_penalty,
+                stop_sequences=stop_sequences,
            )
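Both modules now build their docstrings from the shared GENERATE_FUNCTION_TGIS_ARGS block, so the parameter documentation for run and run_stream_out lives in one place. Below is a minimal sketch of that idea using a stand-in constant (the real text lives in caikit_nlp/toolkit/text_generation/tgis_utils.py). One detail worth knowing: in plain Python, a string literal followed by .format() is an ordinary expression rather than a compile-time docstring, so this sketch attaches the formatted text to __doc__ explicitly for help() and doc tooling to pick it up.

    # Minimal sketch of sharing one Args block across several docstrings.
    # SHARED_GENERATE_ARGS is a stand-in for GENERATE_FUNCTION_TGIS_ARGS.
    SHARED_GENERATE_ARGS = """text: str
            Source string to be encoded for generation.
        max_new_tokens: int
            The maximum number of tokens to generate. Default: 20."""

    DOC_TEMPLATE = """Run inference against the model running in TGIS.
    Args:
        {}
    Returns:
        GeneratedTextResult
    """

    def run(text, max_new_tokens=20):
        ...  # generation body elided; only the doc mechanism matters here

    def run_stream_out(text, max_new_tokens=20):
        ...  # streaming body elided

    # Attach the formatted docs so both functions document the same parameters.
    run.__doc__ = DOC_TEMPLATE.format(SHARED_GENERATE_ARGS)
    run_stream_out.__doc__ = DOC_TEMPLATE.format(SHARED_GENERATE_ARGS)

    help(run)  # shows the shared parameter block

The payoff of this design is that adding a parameter (as this PR does for top_k, temperature, and friends) means updating a single constant instead of four hand-maintained docstrings.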