diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py
index c7e465c8c1494..a64e58f5db251 100644
--- a/libs/langchain/langchain/llms/sagemaker_endpoint.py
+++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py
@@ -1,6 +1,8 @@
 """Sagemaker InvokeEndpoint API."""
+import io
+import json
 from abc import abstractmethod
-from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -8,7 +10,66 @@
 from langchain.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
-OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
+
+
+class LineIterator:
+    """
+    A helper class for parsing the byte stream input.
+
+    The output of the model will be in the following format:
+
+    b'{"outputs": [" a"]}\n'
+    b'{"outputs": [" challenging"]}\n'
+    b'{"outputs": [" problem"]}\n'
+    ...
+
+    While usually each PayloadPart event from the event stream will
+    contain a byte array with a full JSON object, this is not guaranteed
+    and some of the JSON objects may be split across PayloadPart events.
+
+    For example:
+
+    {'PayloadPart': {'Bytes': b'{"outputs": '}}
+    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+
+
+    This class accounts for this by buffering incoming bytes internally
+    and exposing complete lines (ending with a '\n' character) one at a
+    time when the object is iterated over.
+    It maintains the position of the last read byte to ensure
+    that previous bytes are not exposed again.
+
+    For more details see:
+    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
+    """
+
+    def __init__(self, stream: Any) -> None:
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self) -> "LineIterator":
+        return self
+
+    def __next__(self) -> Any:
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Unknown Event Type
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
 
 
 class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
         and the endpoint.
     """
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
     """
     Example:
         .. code-block:: python
@@ -264,22 +328,43 @@ def _call(
         content_type = self.content_handler.content_type
         accepts = self.content_handler.accepts
 
-        # send request
-        try:
-            response = self.client.invoke_endpoint(
-                EndpointName=self.endpoint_name,
-                Body=body,
-                ContentType=content_type,
-                Accept=accepts,
-                **_endpoint_kwargs,
-            )
-        except Exception as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
+        if self.streaming and run_manager:
+            try:
+                resp = self.client.invoke_endpoint_with_response_stream(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=self.content_handler.content_type,
+                    **_endpoint_kwargs,
+                )
+                iterator = LineIterator(resp["Body"])
+                current_completion: str = ""
+                for line in iterator:
+                    resp = json.loads(line)
+                    resp_output = resp.get("outputs")[0]
+                    if stop is not None:
+                        # Uses same approach as below
+                        resp_output = enforce_stop_tokens(resp_output, stop)
+                    current_completion += resp_output
+                    run_manager.on_llm_new_token(resp_output)
+                return current_completion
+            except Exception as e:
+                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+        else:
+            try:
+                response = self.client.invoke_endpoint(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=content_type,
+                    Accept=accepts,
+                    **_endpoint_kwargs,
+                )
+            except Exception as e:
+                raise ValueError(f"Error raised by inference endpoint: {e}")
 
-        text = self.content_handler.transform_output(response["Body"])
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to the sagemaker endpoint.
-            text = enforce_stop_tokens(text, stop)
+            text = self.content_handler.transform_output(response["Body"])
+            if stop is not None:
+                # This is a bit hacky, but I can't figure out a better way to enforce
+                # stop tokens when making calls to the sagemaker endpoint.
+                text = enforce_stop_tokens(text, stop)
 
-        return text
+            return text
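
Reviewer note: a minimal sketch of how the new `LineIterator` reassembles newline-delimited JSON that is split across `PayloadPart` events. The `fake_stream` list below is a hypothetical stand-in for the `resp["Body"]` event stream returned by `invoke_endpoint_with_response_stream`; real events have the same `{"PayloadPart": {"Bytes": ...}}` shape.

```python
from langchain.llms.sagemaker_endpoint import LineIterator

# Hypothetical event stream: the second JSON line arrives split across
# two PayloadPart events, plus one unrelated event that is skipped.
fake_stream = [
    {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"SomeOtherEvent": {}},  # ignored: no PayloadPart key
    {"PayloadPart": {"Bytes": b'[" problem"]}\n'}},
]

for line in LineIterator(fake_stream):
    print(line)
# b'{"outputs": [" a"]}'
# b'{"outputs": [" problem"]}'
```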
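
And a hedged usage sketch of the new `streaming` flag. The endpoint name, region, and response schema below are placeholders, not part of this PR; streaming assumes the container emits the newline-delimited `{"outputs": [...]}` JSON shown in the `LineIterator` docstring, and tokens are surfaced via `on_llm_new_token`, so a streaming callback is needed to observe them.

```python
import json

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Only used on the non-streaming path; the schema is a placeholder.
        return json.loads(output.read().decode("utf-8"))["outputs"][0]


llm = SagemakerEndpoint(
    endpoint_name="my-streaming-endpoint",  # placeholder name
    region_name="us-east-1",                # placeholder region
    content_handler=ContentHandler(),
    streaming=True,  # new flag from this PR
    callbacks=[StreamingStdOutCallbackHandler()],
)
llm("Tell me a joke")
```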