Commit

update pydocs for streamer
Jacob-Zhou committed Nov 11, 2024
1 parent 9e62192 commit f5d87b2
Showing 1 changed file with 53 additions and 22 deletions.
75 changes: 53 additions & 22 deletions lmcsc/streamer.py
@@ -6,31 +6,49 @@

class BeamStreamer(BaseStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
A streamer class that handles beam search output streaming during text generation.
Warnings:
The API for the streamer classes is still under development and may change in the future.
This class extends BaseStreamer to provide functionality for streaming beam search results,
processing the beam hypotheses, and providing an iterator interface for accessing the generated text.
Notes:
This class only supports batch size 1.
Parameters:
tokenizer (`AutoTokenizer`):
The tokenized used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
The tokenizer used to decode the tokens into text.
timeout (`float`, *optional*, defaults to `None`):
The timeout in seconds for queue operations. If None, queue operations block indefinitely.
**decode_kwargs:
Additional keyword arguments passed to the tokenizer's decode method.
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
Attributes:
tokenizer (`AutoTokenizer`):
The tokenizer instance used for decoding.
decode_kwargs (`dict`):
Additional arguments for token decoding.
print_len (`int`):
Length of previously printed text.
text_queue (`Queue`):
Queue for storing generated text chunks.
stop_signal:
Signal used to indicate end of stream.
timeout (`float`):
Timeout value for queue operations.
last_text (`str`):
Most recently generated text.
Examples:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from lmcsc.streamer import BeamStreamer
>>>
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> streamer = BeamStreamer(tokenizer)
>>>
>>> # Stream generated text
>>> for text in streamer:
... print(text)
"""

    def __init__(
@@ -52,7 +70,14 @@ def __init__(

    def put(self, value: BeamSearchScorer):
        """
        Receives tokens, decodes them, and puts the decoded text into the queue.

        Args:
            value (tuple): A tuple containing `(BeamSearchScorer, decoded_text)`.
                The `BeamSearchScorer` contains the beam hypotheses and `decoded_text` is a list of token IDs.

        Raises:
            ValueError: If the batch size is greater than 1.
        """
        beam_scorer, decoded_text = value
        if (len(beam_scorer._beam_hyps) // beam_scorer.num_beam_groups) > 1:
@@ -78,11 +103,17 @@ def put(self, value: BeamSearchScorer):
            self.on_finalized_text(text)

    def end(self):
        """Signals the end of the stream by putting the stop signal in the queue."""
        self.on_finalized_text(self.last_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """
        Puts finalized text into the queue and handles stream-end signaling.

        Args:
            text (str): The text to put in the queue.
            stream_end (bool, optional): Whether this is the end of the stream. Defaults to `False`.
        """
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
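The iterator interface referenced in the docstring example (`for text in streamer`) is not part of this hunk. A minimal sketch of the queue-plus-sentinel pattern it implies, assuming `BeamStreamer` follows the same convention as Hugging Face's `TextIteratorStreamer` and uses the `text_queue`, `stop_signal`, and `timeout` attributes documented above (the class name below is purely illustrative):

```python
from queue import Queue


class QueueIteratorSketch:
    """Illustrative only: how a queue-backed streamer typically exposes iteration."""

    def __init__(self, timeout=None):
        self.text_queue = Queue()
        self.stop_signal = None  # sentinel that end() pushes into the queue
        self.timeout = timeout

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the producer (put()/on_finalized_text()) enqueues text;
        # the sentinel pushed at stream end terminates the for-loop cleanly.
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        return value
```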
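A fuller end-to-end sketch of how the streamer would be consumed, with the producer running in a worker thread so the main thread can drain the queue. The producer here is a hypothetical stand-in: the real lmcsc generation loop is expected to call `streamer.put((beam_scorer, decoded_text))` and `streamer.end()`, and the sample chunks are arbitrary.

```python
from threading import Thread

from transformers import AutoTokenizer
from lmcsc.streamer import BeamStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
streamer = BeamStreamer(tokenizer)


def fake_generation():
    # Hypothetical stand-in for the lmcsc generation loop. The real producer
    # calls streamer.put((beam_scorer, decoded_text)) after each step; here we
    # feed text through on_finalized_text() only to exercise the queue.
    for chunk in ["今天", "今天天气", "今天天气很好"]:
        streamer.on_finalized_text(chunk)
    streamer.on_finalized_text("今天天气很好", stream_end=True)


# The producer runs in a worker thread so iterating the streamer does not block it.
thread = Thread(target=fake_generation)
thread.start()

for text in streamer:  # ends once the stop signal is enqueued
    print(text)

thread.join()
```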
