Added Streaming Capability to SageMaker LLMs #10535

Merged
merged 10 commits on Oct 5, 2023
123 changes: 104 additions & 19 deletions libs/langchain/langchain/llms/sagemaker_endpoint.py
@@ -1,14 +1,75 @@
"""Sagemaker InvokeEndpoint API."""
import io
import json
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator

INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])


class LineIterator:
    """
    A helper class for parsing the byte stream input.

    The output of the model will be in the following format:

    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...

    While usually each PayloadPart event from the event stream will
    contain a byte array with a full JSON object, this is not guaranteed
    and some of the JSON objects may be split across PayloadPart events.

    For example:

    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    This class accounts for this by buffering incoming bytes in an
    io.BytesIO object and yielding only complete lines (ending with a
    '\n' character). It keeps track of the last read position so that
    previously returned bytes are not yielded again.

    For more details see:
    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
    """

    def __init__(self, stream: Any) -> None:
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self) -> "LineIterator":
        return self

    def __next__(self) -> Any:
        while True:
            # Return the next complete line already sitting in the buffer.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            # No complete line yet; pull the next event from the stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                # Unknown event type, skip it.
                continue
            # Append the new bytes to the end of the buffer.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
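
For illustration, a minimal sketch (not part of the diff) of how the LineIterator above reassembles a JSON object that arrives split across two PayloadPart events; the hand-built event dicts below are stand-ins for a real boto3 event stream:

    # Assumes LineIterator as defined above, plus the json import from this module.
    events = [
        {"PayloadPart": {"Bytes": b'{"outputs": '}},
        {"PayloadPart": {"Bytes": b'[" problem"]}\n'}},
    ]
    for line in LineIterator(events):
        print(json.loads(line))  # -> {'outputs': [' problem']}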


class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
and the endpoint.
"""

    streaming: bool = False
    """Whether to stream the results."""

    """
    Example:
        .. code-block:: python
@@ -264,22 +328,43 @@ def _call(
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # send request
        try:
            response = self.client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                Body=body,
                ContentType=content_type,
                Accept=accepts,
                **_endpoint_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")
        if self.streaming and run_manager:
            try:
                resp = self.client.invoke_endpoint_with_response_stream(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=self.content_handler.content_type,
                    **_endpoint_kwargs,
                )
                iterator = LineIterator(resp["Body"])
                current_completion: str = ""
                for line in iterator:
                    resp = json.loads(line)
                    resp_output = resp.get("outputs")[0]
                    if stop is not None:
                        # Uses same approach as below
                        resp_output = enforce_stop_tokens(resp_output, stop)
                    current_completion += resp_output
                    run_manager.on_llm_new_token(resp_output)
                return current_completion
            except Exception as e:
                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
        else:
            try:
                response = self.client.invoke_endpoint(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=content_type,
                    Accept=accepts,
                    **_endpoint_kwargs,
                )
            except Exception as e:
                raise ValueError(f"Error raised by inference endpoint: {e}")

        text = self.content_handler.transform_output(response["Body"])
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to the sagemaker endpoint.
            text = enforce_stop_tokens(text, stop)
            text = self.content_handler.transform_output(response["Body"])
            if stop is not None:
                # This is a bit hacky, but I can't figure out a better way to enforce
                # stop tokens when making calls to the sagemaker endpoint.
                text = enforce_stop_tokens(text, stop)

        return text
            return text
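
To show how the new flag is used end to end, here is a hedged usage sketch (not taken from this PR): the endpoint name and region are placeholders, the "inputs"/"generated_text" request and response keys are assumptions that depend on the deployed model container, and streaming further assumes the container emits newline-delimited '{"outputs": [...]}' chunks as described in the LineIterator docstring.

    import json

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint


    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            # Request schema is container-specific; "inputs" is an assumption.
            return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            # Response schema is container-specific; "generated_text" is an assumption.
            return json.loads(output.read().decode("utf-8"))[0]["generated_text"]


    llm = SagemakerEndpoint(
        endpoint_name="my-streaming-endpoint",  # hypothetical endpoint name
        region_name="us-east-1",
        content_handler=ContentHandler(),
        streaming=True,  # the flag added by this PR
        callbacks=[StreamingStdOutCallbackHandler()],  # receives on_llm_new_token
    )
    print(llm("Tell me a joke"))

Because the streaming branch only runs when a run manager is present, passing a callback handler as above is what actually enables token-by-token delivery; without callbacks the call falls back to the regular invoke_endpoint path.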