Skip to content

Commit

Permalink
Merge pull request #2593 from langchain-ai/nc/2dec/sdk-sse
Browse files Browse the repository at this point in the history
sdk-py: Fix SSE parsing to split lines only on \n, \r, and \r\n; remove httpx-sse dependency; fix missing decoder flush
  • Loading branch information
nfcampos authored Dec 3, 2024
2 parents 2d87195 + 3bf92d0 commit dd010e9
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 30 deletions.
55 changes: 38 additions & 17 deletions libs/sdk-py/langgraph_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)

import httpx
import httpx_sse
import orjson
from httpx._types import QueryParamTypes

Expand Down Expand Up @@ -61,6 +60,7 @@
ThreadStatus,
ThreadUpdateStateResponse,
)
from langgraph_sdk.sse import SSEDecoder, aiter_lines_raw, iter_lines_raw

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -282,22 +282,35 @@ async def stream(
) -> AsyncIterator[StreamPart]:
"""Stream results using SSE."""
headers, content = await aencode_json(json)
async with httpx_sse.aconnect_sse(
self.client, method, path, headers=headers, content=content
) as sse:
headers["Accept"] = "text/event-stream"
headers["Cache-Control"] = "no-store"

async with self.client.stream(
method, path, headers=headers, content=content
) as res:
# check status
try:
sse.response.raise_for_status()
res.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await sse.response.aread()).decode()
body = (await res.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
async for event in sse.aiter_sse():
yield StreamPart(
event.event, orjson.loads(event.data) if event.data else None
# check content type
content_type = res.headers.get("content-type", "").partition(";")[0]
if "text/event-stream" not in content_type:
raise httpx.TransportError(
"Expected response header Content-Type to contain 'text/event-stream', "
f"got {content_type!r}"
)
# parse SSE
decoder = SSEDecoder()
async for line in aiter_lines_raw(res):
sse = decoder.decode(line=line.rstrip(b"\n"))
if sse is not None:
yield sse


async def aencode_json(json: Any) -> tuple[dict[str, str], bytes]:
Expand Down Expand Up @@ -2438,22 +2451,30 @@ def stream(
) -> Iterator[StreamPart]:
"""Stream the results of a request using SSE."""
headers, content = encode_json(json)
with httpx_sse.connect_sse(
self.client, method, path, headers=headers, content=content
) as sse:
with self.client.stream(method, path, headers=headers, content=content) as res:
# check status
try:
sse.response.raise_for_status()
res.raise_for_status()
except httpx.HTTPStatusError as e:
body = sse.response.read().decode()
body = (res.read()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
for event in sse.iter_sse():
yield StreamPart(
event.event, orjson.loads(event.data) if event.data else None
# check content type
content_type = res.headers.get("content-type", "").partition(";")[0]
if "text/event-stream" not in content_type:
raise httpx.TransportError(
"Expected response header Content-Type to contain 'text/event-stream', "
f"got {content_type!r}"
)
# parse SSE
decoder = SSEDecoder()
for line in iter_lines_raw(res):
sse = decoder.decode(line.rstrip(b"\n"))
if sse is not None:
yield sse


def encode_json(json: Any) -> tuple[dict[str, str], bytes]:
Expand Down
148 changes: 148 additions & 0 deletions libs/sdk-py/langgraph_sdk/sse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec."""

from typing import AsyncIterator, Iterator, Optional, Union

import httpx
import orjson

from langgraph_sdk.schema import StreamPart

BytesLike = Union[bytes, bytearray, memoryview]


class BytesLineDecoder:
    """Incremental line splitter over a byte stream.

    Mirrors the stdlib ``bytes.splitlines`` semantics (LF, CR and CRLF all
    terminate a line) but consumes the input chunk by chunk, buffering any
    unterminated trailing portion until its terminator arrives.
    """

    def __init__(self) -> None:
        # Bytes of the current, not-yet-terminated line.
        self.buffer = bytearray()
        # True when the previous chunk ended in a lone CR: the terminator
        # might continue as CRLF in the next chunk, so defer the decision.
        self.trailing_cr: bool = False

    def decode(self, text: bytes) -> list[BytesLike]:
        """Feed one chunk; return every line completed by it."""
        # Terminator characters, matching bytes.splitlines (universal newlines).
        NEWLINE_CHARS = b"\n\r"

        # Re-attach a deferred CR so that a following LF collapses into a
        # single CRLF terminator rather than producing a spurious blank line.
        if self.trailing_cr:
            text = b"\r" + text
            self.trailing_cr = False
        if text.endswith(b"\r"):
            self.trailing_cr = True
            text = text[:-1]

        if not text:
            # NOTE: the edge case input of empty text doesn't occur in practice,
            # because other httpx internals filter out this value
            return []  # pragma: no cover

        ends_on_newline = text[-1] in NEWLINE_CHARS
        pieces: list[BytesLike] = list(text.splitlines())

        if len(pieces) == 1 and not ends_on_newline:
            # The chunk contains no terminator at all: stash and wait for more.
            self.buffer.extend(pieces[0])
            return []

        if self.buffer:
            # Prepend the previously buffered partial line to the first
            # completed piece of this chunk.
            head = self.buffer
            head.extend(pieces[0])
            pieces[0] = head
            self.buffer = bytearray()

        if not ends_on_newline:
            # The final piece is still unterminated: keep it for later.
            self.buffer.extend(pieces.pop())

        return pieces

    def flush(self) -> list[BytesLike]:
        """Emit whatever is still buffered once the stream has ended."""
        if not self.buffer and not self.trailing_cr:
            return []

        remainder: list[BytesLike] = [self.buffer]
        self.buffer = bytearray()
        self.trailing_cr = False
        return remainder


class SSEDecoder:
    """Incremental Server-Sent Events decoder.

    Feed it one line at a time (without the trailing newline); a completed
    event is returned as a ``StreamPart`` when the blank separator line
    arrives.
    """

    def __init__(self) -> None:
        self._event = ""
        # One entry per `data:` line; joined with "\n" at dispatch time,
        # as the SSE spec requires.
        self._data: list[bytes] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None

    def decode(self, line: bytes) -> Optional[StreamPart]:
        """Consume one SSE line; return a StreamPart when an event completes,
        else None.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501

        if not line:
            # Blank line: dispatch the buffered event, unless nothing was set.
            if (
                not self._event
                and not self._data
                and not self._last_event_id
                and self._retry is None
            ):
                return None

            # Per the SSE spec, multiple `data:` lines are joined with a
            # single "\n" between them. (Previously the values were
            # concatenated with no separator, corrupting multi-line payloads.)
            data = b"\n".join(self._data)
            sse = StreamPart(
                event=self._event,
                data=orjson.loads(data) if data else None,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = ""
            self._data = []
            self._retry = None

            return sse

        if line.startswith(b":"):
            # Comment line: ignored.
            return None

        fieldname, _, value = line.partition(b":")

        # A single space after the colon is not part of the value.
        if value.startswith(b" "):
            value = value[1:]

        if fieldname == b"event":
            self._event = value.decode()
        elif fieldname == b"data":
            self._data.append(value)
        elif fieldname == b"id":
            # Ids containing a NUL byte are ignored per the spec.
            if b"\0" in value:
                pass
            else:
                self._last_event_id = value.decode()
        elif fieldname == b"retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        else:
            pass  # Field is ignored.

        return None


async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]:
    """Asynchronously yield raw SSE lines from *response*.

    Lines are split on LF, CR and CRLF; any unterminated final line is
    flushed once the body is exhausted.
    """
    line_decoder = BytesLineDecoder()
    async for chunk in response.aiter_bytes():
        for complete_line in line_decoder.decode(chunk):
            yield complete_line
    for complete_line in line_decoder.flush():
        yield complete_line


def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]:
    """Yield raw SSE lines from *response*.

    Lines are split on LF, CR and CRLF; any unterminated final line is
    flushed once the body is exhausted.
    """
    line_decoder = BytesLineDecoder()
    for chunk in response.iter_bytes():
        yield from line_decoder.decode(chunk)
    yield from line_decoder.flush()
13 changes: 1 addition & 12 deletions libs/sdk-py/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion libs/sdk-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ packages = [{ include = "langgraph_sdk" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
httpx = ">=0.25.2"
httpx-sse = ">=0.4.0"
orjson = ">=3.10.1"

[tool.poetry.group.dev.dependencies]
Expand Down

0 comments on commit dd010e9

Please sign in to comment.