
Merge pull request #152 from gkumbhat/flattened_inf_params
Flattened inf params
gkumbhat authored Aug 31, 2023
2 parents 77bc04e + 295f9f7 commit 4db0c26
Showing 5 changed files with 636 additions and 174 deletions.
183 changes: 100 additions & 83 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -38,7 +38,6 @@
AutoConfig,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
TextStreamer,
default_data_collator,
)
from transformers.models.auto.tokenization_auto import AutoTokenizer
@@ -59,7 +58,12 @@
import alog

# Local
from ...data_model import GenerationTrainRecord, PromptOutputModelType, TuningConfig
from ...data_model import (
ExponentialDecayLengthPenalty,
GenerationTrainRecord,
PromptOutputModelType,
TuningConfig,
)
from ...resources.pretrained_model import (
HFAutoCausalLM,
HFAutoSeq2SeqLM,
@@ -68,6 +72,11 @@
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
from ...toolkit.task_specific_utils import convert_to_generation_record
from ...toolkit.text_generation.model_run_utils import (
GENERATE_FUNCTION_ARGS,
generate_text_func,
generate_text_func_stream,
)
from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer

log = alog.use_channel("PEFT_PROMPT")
@@ -94,13 +103,6 @@ class TuningType(str, Enum):
# LORA = "LORA"


class Streamer(TextStreamer):
# The default TextStreamer currently prints to stdout
# so we override that here
def on_finalized_text(self, text: str, stream_end: bool = False):
pass


# TODO: try to refactor this into a smaller module
# pylint: disable=too-many-lines,too-many-instance-attributes
@module(
@@ -170,53 +172,55 @@ def __del__(self):
def run(
self,
text: str,
device: Optional[Union[str, int]] = None,
max_new_tokens=20,
min_new_tokens=0,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 1.0,
typical_p: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 1.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> GeneratedTextResult:
"""Run the full text generation model.
Args:
text: str
Input string to be used to the generation model.
device: Optional[Union[str, int]]
Deprecated. By default, we use the detected device.
max_new_tokens: int
The maximum numbers of tokens to generate.
Default: 20
min_new_tokens: int
The minimum numbers of tokens to generate.
Default: 0 - means no minimum
Returns:
GeneratedTextResult
Generated text result produced by PEFT / Transformers.
"""
if device is not None:
log.warning(
"Specifying device is deprecated and ignored, please update your calling argument"
)
device = self._DETECT_DEVICE
# Apply the verbalizer to our text string
Run the full text generation model.
Args:
{}
Returns:
GeneratedTextResult
Generated text result produced by PEFT / Transformers.
""".format(
GENERATE_FUNCTION_ARGS
)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
# Apply the tokenizer to the sample text & move to correct device
tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt")

device = PeftPromptTuning._get_device(device)
inputs = {k: v.to(device) for k, v in tok_tensors.items()}
with torch.no_grad():
# Run tokenized tensors through the rest of the PEFT model
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
eos_token_id=self.eos_token_id,
)
gen_text = self.tokenizer.batch_decode(
outputs.detach().cpu().numpy(), skip_special_tokens=True
)
return GeneratedTextResult(generated_text=gen_text[0])

return generate_text_func(
self.model,
self.tokenizer,
self.PRODUCER_ID,
self.tokenizer.eos_token,
verbalized_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)

# NOTE: We need to disable wip decorator here otherwise we get issues in
# proto generation for streaming. We are keeping it commented out for now,
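
For orientation, here is a minimal usage sketch of the flattened run signature introduced above. It is not part of the commit; the import path, load call, model path, and parameter values are assumptions chosen for illustration, and only the keyword argument names mirror the new signature.

# --- Usage sketch (illustrative only, not part of this commit) ---
from caikit_nlp.modules.text_generation import PeftPromptTuning  # import path assumed

# Assumes a prompt-tuned model was previously saved to this placeholder path.
peft_model = PeftPromptTuning.load("path/to/tuned_model")

result = peft_model.run(
    "What is the boiling point of nitrogen?",
    max_new_tokens=50,
    min_new_tokens=5,
    decoding_method="SAMPLING",  # default shown in the diff is "GREEDY"
    top_k=50,
    top_p=0.9,
    temperature=0.7,
    seed=42,
    repetition_penalty=1.2,
)
print(result.generated_text)
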
@@ -226,7 +230,23 @@ def run(
# )
@TextGenerationTask.taskmethod(output_streaming=True)
def run_stream_out(
self, text: str, max_new_tokens=20, min_new_tokens=0
self,
text: str,
max_new_tokens=20,
min_new_tokens=0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 0.0,
typical_p: Optional[float] = 0.0,
temperature: Optional[float] = 1.0,
seed: Optional[int] = None,
repetition_penalty: Optional[float] = 0.0,
max_time: Optional[float] = None,
exponential_decay_length_penalty: Optional[
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[str] = None,
) -> Iterable[GeneratedTextStreamResult]:
"""Run the text generation model with output streaming
@@ -236,40 +256,37 @@ def run_stream_out(
Ref. https://huggingface.co/docs/transformers/v4.30.0/generation_strategies#streaming
Args:
text: str
Input string to be used to the generation model.
max_new_tokens: int
The maximum numbers of tokens to generate.
Default: 20
min_new_tokens: int
The minimum numbers of tokens to generate.
Default: 0 - means no minimum
{}
Returns:
Iterable[GeneratedTextStreamResult]
"""
""".format(
GENERATE_FUNCTION_ARGS
)

# Apply the verbalizer to our text string
verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
# Apply the tokenizer to the sample text & move to correct device
tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt")
inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()}

streamer = Streamer(self.tokenizer)
with torch.no_grad():
# Run tokenized tensors through the rest of the PEFT model
stream_outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
eos_token_id=self.eos_token_id,
streamer=streamer,
)
for stream_part in stream_outputs:
gen_text = self.tokenizer.batch_decode(
stream_part.detach().cpu().numpy(), skip_special_tokens=True
)
yield GeneratedTextStreamResult(generated_text=gen_text)

return generate_text_func_stream(
self.model,
self.tokenizer,
self.PRODUCER_ID,
self.tokenizer.eos_token,
verbalized_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
seed=seed,
repetition_penalty=repetition_penalty,
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
)

@classmethod
def train(
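
A similarly hedged sketch of the streaming path through run_stream_out, reusing the peft_model from the sketch above; the prompt and parameter values are again illustrative only.

# --- Streaming sketch (illustrative only, assumes `peft_model` from above) ---
for stream_part in peft_model.run_stream_out(
    "Write one sentence about prompt tuning.",
    max_new_tokens=64,
    decoding_method="GREEDY",
):
    # Each GeneratedTextStreamResult carries the newly generated text for that chunk.
    print(stream_part.generated_text, end="", flush=True)
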
127 changes: 38 additions & 89 deletions caikit_nlp/modules/text_generation/text_generation_local.py
@@ -40,6 +40,10 @@
)
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
from ...toolkit.text_generation.model_run_utils import (
GENERATE_FUNCTION_ARGS,
generate_text_func,
)

log = alog.use_channel("TXT_GEN")
error = error_handler.get(log)
@@ -380,106 +384,51 @@ def save(self, model_path):
def run(
self,
text: str,
repetition_penalty: float = 2.5,
length_penalty: float = 1.0,
early_stopping: bool = True,
num_beams: int = 1,
max_new_tokens: int = 20,
min_new_tokens: int = 0,
truncate_input_tokens: int = 0,
max_new_tokens: Optional[int] = 20,
min_new_tokens: Optional[int] = 0,
truncate_input_tokens: Optional[int] = 0,
decoding_method: Optional[str] = "GREEDY",
top_k: Optional[int] = 0,
top_p: Optional[float] = 0.0,
typical_p: Optional[float] = 0.0,
temperature: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 0.0,
max_time: Optional[float] = None,
**kwargs,
) -> "GeneratedTextResult":
"""Run inference against the model running in TGIS.
) -> GeneratedTextResult:

Args:
text: str
Source string to be encoded for generation.
repetition_penalty: float
The parameter for repetition penalty. 1.0 means no penalty.
Default: 2.5
length_penalty: float
Exponential penalty to the length that is used with beam-based generation.
It is applied as an exponent to the sequence length, \
which is used to divide the score of the sequence.
Since the score is the log likelihood of the sequence (i.e. negative), \
length_penalty > 0.0 promotes longer sequences, \
while length_penalty < 0.0 encourages shorter sequences.
Default: 1.0.
early_stopping: bool
Controls the stopping condition for beam-based methods, like beam-search.
It accepts the following values:
True, where the generation stops as soon as there are num_beams complete candidates;
False, where an heuristic is applied and the generation stops when \
is it very unlikely to find better candidates;
"never", where the beam search procedure only stops \
when there cannot be better candidates (canonical beam search algorithm).
num_beams: int
Number of beams for beam search. 1 means no beam search.
Default: 1
max_new_tokens: int
The maximum numbers of tokens to generate.
Default: 20
min_new_tokens: int
The minimum numbers of tokens to generate.
Default: 0 - means no minimum
truncate_input_tokens: int
Truncate inputs to provided number of tokens. This can be
use to avoid failing due to input being longer than
configured limits.
Default: 0 - means don't truncate, thus throw error.
kwargs:
Any other parameters to pass to generate as specified in GenerationConfig.
https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/text_generation#transformers.GenerationConfig
Returns:
GeneratedTextResult
Generated text result produced by the model.
"""
Run the full text generation model.
Args:
{}
Returns:
GeneratedTextResult
Generated text result produced by the model.
""".format(
GENERATE_FUNCTION_ARGS
)

# NOTE: below is to match TGIS API, where 0 identifies as no truncation
if truncate_input_tokens == 0:
# NOTE: below will make model throw error in case inputs are longer
# than allowed length
truncation = False

else:
truncation = True
# TODO: Beam search currently not supported

inputs = self.model.tokenizer(
return generate_text_func(
# Pass HF model
self.model.model,
self.model.tokenizer,
self.PRODUCER_ID,
self._eos_token,
text,
truncation=truncation,
max_length=truncate_input_tokens,
return_tensors="pt",
)
generate_ids = self.model.model.generate(
input_ids=inputs["input_ids"],
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
decoding_method=decoding_method,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
early_stopping=early_stopping,
use_cache=True,
max_time=max_time,
**kwargs,
)
token_count = generate_ids.size(1) - 1
preds = [
self.model.tokenizer.decode(
g, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for g in generate_ids
]
if generate_ids[0][-1].item() == self._eos_token:
finish_reason = "EOS_TOKEN"
elif generate_ids.size(1) - 1 == max_new_tokens:
finish_reason = "MAX_TOKENS"
else:
finish_reason = "OTHER"
return GeneratedTextResult(
generated_tokens=token_count,
generated_text=preds[0],
finish_reason=finish_reason,
producer_id=self.PRODUCER_ID,
)

################################## Private Functions ######################################

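
Because both run implementations now delegate to the shared generate_text_func helper, the same flattened keyword arguments apply to the local text generation module as well. A hedged sketch follows; the TextGeneration class name, import path, and model path are assumptions based on the file shown above.

# --- Parallel usage sketch for the local module (illustrative only) ---
from caikit_nlp.modules.text_generation import TextGeneration  # class and import path assumed

local_model = TextGeneration.load("path/to/saved_text_generation_model")

result = local_model.run(
    "Summarize: caikit-nlp now flattens inference parameters into keyword arguments.",
    max_new_tokens=30,
    decoding_method="GREEDY",
    repetition_penalty=1.2,
    max_time=10.0,  # seconds; forwarded to generate_text_func
)
print(result.generated_text)
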
(Diffs for the remaining 3 changed files in this commit are not shown here.)
