From 763bf99a86422a474b1115c61c95dae00acaeb28 Mon Sep 17 00:00:00 2001 From: Juan Daza Date: Thu, 7 Sep 2023 07:17:48 +0800 Subject: [PATCH 1/4] Added Initial Version for SageMaker Streaming based on AWS Blog Post --- .../langchain/llms/sagemaker_endpoint.py | 123 +++++++++++++++--- 1 file changed, 104 insertions(+), 19 deletions(-) diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index e0383552b2f25..3470399852957 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -1,6 +1,8 @@ """Sagemaker InvokeEndpoint API.""" +import io +import json from abc import abstractmethod -from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -8,7 +10,66 @@ from langchain.pydantic_v1 import Extra, root_validator INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]]) -OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]]) +OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator]) + + +class LineIterator: + """ + A helper class for parsing the byte stream input. + + The output of the model will be in the following format: + + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + + While usually each PayloadPart event from the event stream will + contain a byte array with a full json, this is not guaranteed + and some of the json objects may be split acrossPayloadPart events. + + For example: + + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) + within the buffer via the 'scan_lines' function. + It maintains the position of the last read position to ensure + that previous bytes are not exposed again. + + For more details see: + https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/ + """ + + def _init_(self, stream): + self.byte_iterator = iter(stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def _iter_(self): + return self + + def _next_(self): + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if "PayloadPart" not in chunk: + # Unknown Event Type + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]): @@ -122,6 +183,9 @@ class SagemakerEndpoint(LLM): and the endpoint. """ + streaming: bool = False + """Whether to stream the results.""" + """ Example: .. code-block:: python @@ -231,22 +295,43 @@ def _call( content_type = self.content_handler.content_type accepts = self.content_handler.accepts - # send request - try: - response = self.client.invoke_endpoint( - EndpointName=self.endpoint_name, - Body=body, - ContentType=content_type, - Accept=accepts, - **_endpoint_kwargs, - ) - except Exception as e: - raise ValueError(f"Error raised by inference endpoint: {e}") + if self.streaming and run_manager: + try: + resp = self.client.invoke_endpoint_with_response_stream( + EndpointName=self.endpoint_name, + Body=body, + ContentType=self.content_handler.content_type, + **_endpoint_kwargs, + ) + iterator = LineIterator(resp["Body"]) + current_completion: str = "" + for line in iterator: + resp = json.loads(line) + resp_output = resp.get("outputs")[0] + if stop is not None: + # Uses same approach as below + resp_output = enforce_stop_tokens(resp_output, stop) + current_completion += resp_output + run_manager.on_llm_new_token(resp_output) + return current_completion + except Exception as e: + raise ValueError(f"Error raised by streaming inference endpoint: {e}") + else: + try: + response = self.client.invoke_endpoint( + EndpointName=self.endpoint_name, + Body=body, + ContentType=content_type, + Accept=accepts, + **_endpoint_kwargs, + ) + except Exception as e: + raise ValueError(f"Error raised by inference endpoint: {e}") - text = self.content_handler.transform_output(response["Body"]) - if stop is not None: - # This is a bit hacky, but I can't figure out a better way to enforce - # stop tokens when making calls to the sagemaker endpoint. - text = enforce_stop_tokens(text, stop) + text = self.content_handler.transform_output(response["Body"]) + if stop is not None: + # This is a bit hacky, but I can't figure out a better way to enforce + # stop tokens when making calls to the sagemaker endpoint. + text = enforce_stop_tokens(text, stop) - return text + return text \ No newline at end of file From dac2c4762322ea1bfe2b44fd0383b2815a9d26cd Mon Sep 17 00:00:00 2001 From: Juan Daza Date: Thu, 14 Sep 2023 07:30:29 +0800 Subject: [PATCH 2/4] added formatting fix using black --- libs/langchain/langchain/llms/sagemaker_endpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index 3470399852957..3c414ab5a0480 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -18,21 +18,21 @@ class LineIterator: A helper class for parsing the byte stream input. The output of the model will be in the following format: - + b'{"outputs": [" a"]}\n' b'{"outputs": [" challenging"]}\n' b'{"outputs": [" problem"]}\n' ... - + While usually each PayloadPart event from the event stream will contain a byte array with a full json, this is not guaranteed and some of the json objects may be split acrossPayloadPart events. For example: - + {'PayloadPart': {'Bytes': b'{"outputs": '}} {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} - + This class accounts for this by concatenating bytes written via the 'write' function and then exposing a method which will return lines (ending with a '\n' character) @@ -334,4 +334,4 @@ def _call( # stop tokens when making calls to the sagemaker endpoint. text = enforce_stop_tokens(text, stop) - return text \ No newline at end of file + return text From 7a0a46da9f9184218ca8f0ecbc380eec3acc77a8 Mon Sep 17 00:00:00 2001 From: Juan Daza Date: Fri, 22 Sep 2023 11:39:25 +0800 Subject: [PATCH 3/4] added first function fix and underscore fix --- libs/langchain/langchain/llms/sagemaker_endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index 3c414ab5a0480..6fa74b4440d67 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -44,15 +44,15 @@ class LineIterator: https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/ """ - def _init_(self, stream): + def __init__(self, stream) -> None: self.byte_iterator = iter(stream) self.buffer = io.BytesIO() self.read_pos = 0 - def _iter_(self): + def __iter__(self): return self - def _next_(self): + def __next__(self): while True: self.buffer.seek(self.read_pos) line = self.buffer.readline() From 1e1238f714ef30bdfe641603daaa9f545640be29 Mon Sep 17 00:00:00 2001 From: Juan Daza Date: Mon, 25 Sep 2023 07:10:55 +0800 Subject: [PATCH 4/4] Added mypy validation rules --- libs/langchain/langchain/llms/sagemaker_endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index 6fa74b4440d67..eb33ca007c741 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -44,15 +44,15 @@ class LineIterator: https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/ """ - def __init__(self, stream) -> None: + def __init__(self, stream: Any) -> None: self.byte_iterator = iter(stream) self.buffer = io.BytesIO() self.read_pos = 0 - def __iter__(self): + def __iter__(self) -> "LineIterator": return self - def __next__(self): + def __next__(self) -> Any: while True: self.buffer.seek(self.read_pos) line = self.buffer.readline()