Add more text generation inference params to TGIS modules #155

Merged: 15 commits, Aug 31, 2023
Changes from 14 commits
134 changes: 79 additions & 55 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -15,7 +15,7 @@
prompt vectors in TGIS generation requests.
"""
# Standard
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Tuple, Union
import os

# First Party
@@ -31,7 +31,11 @@
import alog

# Local
from ...toolkit.tgis_utils import TGISGenerationClient
from ...data_model import ExponentialDecayLengthPenalty
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
)
from ...toolkit.verbalizer_utils import render_verbalizer
from . import PeftPromptTuning

@@ -161,82 +165,92 @@ def save(self, model_path: str):
@TextGenerationTask.taskmethod()
def run(
self,
text,
preserve_input_text=False,
max_new_tokens=20,
min_new_tokens=0,
truncate_input_tokens=0,
text: str,
preserve_input_text: bool = False,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS. Currently we leverage greedy decoding
and apply the same verbalizer used for training the local model prior to sending the
request to TGIS.

Args:
text: str
Source string to be encoded for generation.
preserve_input_text: bool
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
max_new_tokens: int
The maximum number of tokens to generate.
Default: 20
min_new_tokens: int
The minimum number of tokens to generate.
Default: 0 - means no minimum
truncate_input_tokens: int
Truncate inputs to provided number of tokens. This can be
used to avoid failing due to input being longer than
configured limits.
Default: 0 - means don't truncate, thus throw error.
{}
Review comment (Collaborator): nit: The docstring above could be modified a bit, since we now provide decoding_method as an option; it's not only greedy decoding anymore.

Returns:
GeneratedTextResult
Generated text result produced by TGIS.
"""
""".format(
GENERATE_FUNCTION_TGIS_ARGS
)
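
One caveat worth flagging on the docstring-templating pattern above: in CPython, only a bare string literal is stored as a function's docstring, so a literal immediately followed by a .format(...) call is an ordinary expression statement and __doc__ stays None. A standalone sketch of the difference (function names here are illustrative only, not part of this diff):

```python
def f():
    """Args: {}""".format("x")  # a call expression, so f.__doc__ stays None

def g():
    """Args: {}"""

# Rebinding after definition keeps a formatted docstring visible to help().
g.__doc__ = g.__doc__.format("x")

assert f.__doc__ is None
assert g.__doc__ == "Args: x"
```

Projects that want the formatted text visible to introspection tools typically rebind __doc__ after definition, as in the g example.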

error.value_check(
"<NLP87360638E>",
self.enable_backend,
"Backend must be configured and loaded with this module before executing `run` call.",
)
verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.unary_generate(
verbalized_text,
preserve_input_text,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
text=verbalized_text,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)
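
For orientation, a usage sketch of the widened run() signature. The module handle is hypothetical (a loaded PeftPromptTuningTGIS instance obtained elsewhere), and the "SAMPLING" value for decoding_method is an assumption based on the library's conventions; this diff only shows the "GREEDY" default.

```python
# Hypothetical usage sketch; `module` stands for a loaded
# PeftPromptTuningTGIS instance and is not part of this diff.
result = module.run(
    "Classify the sentiment of this sentence: I loved the film.",
    preserve_input_text=False,
    max_new_tokens=64,
    min_new_tokens=5,
    decoding_method="SAMPLING",  # assumed sampling switch; default is "GREEDY"
    top_k=50,                    # restrict sampling to the 50 likeliest tokens
    top_p=0.9,                   # nucleus sampling cutoff
    temperature=0.7,             # soften the token distribution
    seed=1234,                   # reproducible sampling runs
    repetition_penalty=1.2,      # discourage verbatim repetition
    max_time=10.0,               # time budget (assumed seconds, per HF convention)
)
print(result.generated_text)
```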

@TextGenerationTask.taskmethod(output_streaming=True)
def run_stream_out(
self,
text: str,
preserve_input_text=False,
max_new_tokens=20,
min_new_tokens=0,
truncate_input_tokens=0,
preserve_input_text: bool = False,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing against the model running in TGIS

Args:
text: str
Source string to be encoded for generation.
preserve_input_text: bool
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
max_new_tokens: int
The maximum number of tokens to generate.
Default: 20
min_new_tokens: int
The minimum number of tokens to generate.
Default: 0 - means no minimum
truncate_input_tokens: int
Truncate inputs to provided number of tokens. This can be
used to avoid failing due to input being longer than
configured limits.
Default: 0 - means don't truncate, thus throw error.
{}
Returns:
Iterable[GeneratedTextStreamResult]
"""
""".format(
GENERATE_FUNCTION_TGIS_ARGS
)

error.value_check(
"<NLP62995899E>",
self.enable_backend,
@@ -245,9 +259,19 @@
)
verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.stream_generate(
verbalized_text,
preserve_input_text,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
text=verbalized_text,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)
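
Of the new parameters, exponential_decay_length_penalty is the only one with a composite type: per the Union in the signature, callers can pass either a plain (start_index, decay_factor) tuple or the new ExponentialDecayLengthPenalty data-model object. A hedged sketch of both forms; the field names on the data-model object are assumptions patterned on the HF Transformers parameter of the same name and are not shown in this diff.

```python
from caikit_nlp.data_model import ExponentialDecayLengthPenalty

# Tuple form: start applying the penalty after 32 generated tokens,
# with the decay factor 1.03 increasingly favoring ending the sequence.
penalty = (32, 1.03)

# Data-model form; field names are assumed, not shown in this diff.
penalty_dm = ExponentialDecayLengthPenalty(start_index=32, decay_factor=1.03)

# `module` is a hypothetical loaded TGIS-backed module instance.
result = module.run(
    "Summarize the change log in two sentences.",
    max_new_tokens=100,
    exponential_decay_length_penalty=penalty,  # or penalty_dm
)
```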
130 changes: 79 additions & 51 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,7 +14,7 @@


# Standard
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Tuple, Union
import os

# First Party
@@ -30,18 +30,24 @@
import alog

# Local
from ...data_model import ExponentialDecayLengthPenalty
from ...resources.pretrained_model import (
HFAutoCausalLM,
HFAutoSeq2SeqLM,
PretrainedModelBase,
)
from ...toolkit.tgis_utils import TGISGenerationClient
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
)
from .text_generation_local import TextGeneration

log = alog.use_channel("TXT_GEN")
error = error_handler.get(log)

# pylint: disable=too-many-instance-attributes


@module(backend_type=TGISBackend.backend_type, base_module=TextGeneration)
class TextGenerationTGIS(ModuleBase):
"""Module to provide text generation capabilities"""
@@ -197,76 +203,98 @@ def run(
self,
text: str,
preserve_input_text: bool = False,
max_new_tokens: int = 20,
min_new_tokens: int = 0,
truncate_input_tokens: int = 0,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> GeneratedTextResult:
"""Run inference against the model running in TGIS.

Args:
text: str
Source string to be encoded for generation.
preserve_input_text: bool
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
max_new_tokens: int
The maximum number of tokens to generate.
Default: 20
min_new_tokens: int
The minimum number of tokens to generate.
Default: 0 - means no minimum
truncate_input_tokens: int
Truncate inputs to provided number of tokens. This can be
used to avoid failing due to input being longer than
configured limits.
Default: 0 - means don't truncate, thus throw error.
{}
Returns:
GeneratedTextResult
Generated text result produced by TGIS.
"""
""".format(
GENERATE_FUNCTION_TGIS_ARGS
)

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text,
preserve_input_text,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
text=text,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
def run_stream_out(
self,
text: str,
preserve_input_text=False,
max_new_tokens=20,
min_new_tokens=0,
truncate_input_tokens=0,
preserve_input_text: bool = False,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run output stream inferencing for text generation module.

Args:
text: str
Source string to be encoded for generation.
preserve_input_text: bool
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
max_new_tokens: int
Maximum tokens for the model to generate
min_new_tokens: int
Minimum tokens for the model to generate
truncate_input_tokens: int
Truncate inputs to provided number of tokens. This can be
used to avoid failing due to input being longer than
configured limits.
Default: 0 - means don't truncate, thus throw error.
{}
Returns:
Iterable[GeneratedTextStreamResult]
"""
""".format(
GENERATE_FUNCTION_TGIS_ARGS
)

if self._model_loaded:
return self.tgis_generation_client.stream_generate(
text,
preserve_input_text,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
text=text,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)
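
A matching sketch for the streaming path. Here tgis_module is a hypothetical loaded TextGenerationTGIS instance, and the per-chunk generated_text field follows caikit's GeneratedTextStreamResult as we read it; both are assumptions, not part of this diff.

```python
# Hypothetical streaming consumer; not part of this diff.
stream = tgis_module.run_stream_out(
    "Write a haiku about distributed systems.",
    max_new_tokens=40,
    decoding_method="SAMPLING",  # assumed value; default is "GREEDY"
    temperature=0.8,
    typical_p=0.95,
)
for chunk in stream:
    # Each chunk is a GeneratedTextStreamResult; generated_text is read
    # here as the incremental text (field name per caikit's interfaces).
    print(chunk.generated_text, end="", flush=True)
print()
```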