Skip to content

Commit

Permalink
refactor(internal): move to Stream and AsyncStream classes for streaming
Browse files Browse the repository at this point in the history

refactor(internal): move to `Stream` and `AsyncStream` classes for streaming
  • Loading branch information
stainless-bot authored Mar 17, 2023
1 parent 12eb383 commit e4da92e
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 63 deletions.
4 changes: 4 additions & 0 deletions src/lithic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
ENVIRONMENTS,
Client,
Lithic,
Stream,
Timeout,
Transport,
AsyncClient,
AsyncLithic,
AsyncStream,
ProxiesTypes,
RequestOptions,
)
Expand Down Expand Up @@ -55,6 +57,8 @@
"RequestOptions",
"Client",
"AsyncClient",
"Stream",
"AsyncStream",
"Lithic",
"AsyncLithic",
"ENVIRONMENTS",
Expand Down
155 changes: 100 additions & 55 deletions src/lithic/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,88 @@ class StopStreaming(Exception):
"""Raised internally when processing of a streamed response should be stopped."""


class Stream(Generic[ResponseT]):
    """Synchronous iterator over the decoded items of a streaming HTTP response.

    Each non-empty line of the response body is pre-processed by the owning
    client and then JSON-decoded into an instance of ``cast_to``.
    """

    # the raw HTTP response the items are read from
    response: httpx.Response

    def __init__(
        self,
        *,
        cast_to: type[ResponseT],
        response: httpx.Response,
        client: SyncAPIClient,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        # create the underlying generator eagerly so __next__ can delegate to it
        self._iterator = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        yield from self._iterator

    def __stream__(self) -> Iterator[ResponseT]:
        # Decode the response line-by-line, delegating the protocol details
        # (line pre-processing and model construction) to the client.
        client = self._client
        for raw in self.response.iter_lines():
            # skip blank keep-alive lines
            if not raw or raw == "\n":
                continue

            try:
                payload = client._process_stream_line(raw)
            except StopStreaming:
                # the client signalled the end-of-stream sentinel
                break

            yield client._process_response_data(
                data=json.loads(payload), cast_to=self._cast_to, response=self.response
            )


class AsyncStream(Generic[ResponseT]):
    """Asynchronous iterator over the decoded items of a streaming HTTP response.

    Each non-empty line of the response body is pre-processed by the owning
    client and then JSON-decoded into an instance of ``cast_to``.
    """

    # the raw HTTP response the items are read from
    response: httpx.Response

    def __init__(
        self,
        *,
        cast_to: type[ResponseT],
        response: httpx.Response,
        client: AsyncAPIClient,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        # create the underlying async generator eagerly so __anext__ can delegate to it
        self._iterator = self.__stream__()

    async def __anext__(self) -> ResponseT:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ResponseT]:
        async for element in self._iterator:
            yield element

    async def __stream__(self) -> AsyncIterator[ResponseT]:
        # Decode the response line-by-line, delegating the protocol details
        # (line pre-processing and model construction) to the client.
        client = self._client
        async for raw in self.response.aiter_lines():
            # skip blank keep-alive lines
            if not raw or raw == "\n":
                continue

            try:
                payload = client._process_stream_line(raw)
            except StopStreaming:
                # the client signalled the end-of-stream sentinel
                break

            yield client._process_response_data(
                data=json.loads(payload), cast_to=self._cast_to, response=self.response
            )


class PageInfo:
"""Stores the necesary information to build the request to retrieve the next page.
Expand Down Expand Up @@ -526,7 +608,6 @@ def _process_response_data(

return cast(ResponseT, construct_type(type_=cast_to, value=data))

# TODO: make the constants in here configurable
def _process_stream_line(self, contents: str) -> str:
"""Pre-process an indiviudal line from a streaming response"""
if contents == "data: [DONE]\n":
Expand Down Expand Up @@ -690,7 +771,7 @@ def request(
remaining_retries: Optional[int] = None,
*,
stream: Literal[True],
) -> Iterator[ResponseT]:
) -> Stream[ResponseT]:
...

@overload
Expand All @@ -712,7 +793,7 @@ def request(
remaining_retries: Optional[int] = None,
*,
stream: bool = False,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
...

def request(
Expand All @@ -722,7 +803,7 @@ def request(
remaining_retries: Optional[int] = None,
*,
stream: bool = False,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
return self._request(
cast_to=cast_to,
options=options,
Expand All @@ -737,7 +818,7 @@ def _request(
options: FinalRequestOptions,
remaining_retries: int | None,
stream: bool,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)

Expand All @@ -762,7 +843,7 @@ def _request(
raise APIConnectionError(request=request) from err

if stream:
return self._process_stream_response(cast_to=cast_to, response=response)
return Stream(cast_to=cast_to, response=response, client=self)

try:
rsp = self._process_response(cast_to=cast_to, options=options, response=response)
Expand All @@ -779,7 +860,7 @@ def _retry_request(
response_headers: Optional[httpx.Headers] = None,
*,
stream: bool,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
remaining = remaining_retries - 1
timeout = self._calculate_retry_timeout(remaining, options, response_headers)

Expand All @@ -794,24 +875,6 @@ def _retry_request(
stream=stream,
)

def _process_stream_response(
self,
*,
cast_to: Type[ResponseT],
response: httpx.Response,
) -> Iterator[ResponseT]:
for raw_line in response.iter_lines():
if not raw_line or raw_line == "\n":
continue

try:
line = self._process_stream_line(raw_line)
except StopStreaming:
# we are done!
break

yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response)

def _request_api_list(
self,
model: Type[ModelT],
Expand Down Expand Up @@ -861,7 +924,7 @@ def post(
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: Literal[True],
) -> Iterator[ResponseT]:
) -> Stream[ResponseT]:
...

@overload
Expand All @@ -874,7 +937,7 @@ def post(
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: bool,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
...

def post(
Expand All @@ -886,7 +949,7 @@ def post(
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: bool = False,
) -> ResponseT | Iterator[ResponseT]:
) -> ResponseT | Stream[ResponseT]:
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
return cast(ResponseT, self.request(cast_to, opts, stream=stream))

Expand Down Expand Up @@ -993,7 +1056,7 @@ async def request(
*,
stream: Literal[True],
remaining_retries: Optional[int] = None,
) -> AsyncIterator[ResponseT]:
) -> AsyncStream[ResponseT]:
...

@overload
Expand All @@ -1004,7 +1067,7 @@ async def request(
*,
stream: bool,
remaining_retries: Optional[int] = None,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
...

async def request(
Expand All @@ -1014,7 +1077,7 @@ async def request(
*,
stream: bool = False,
remaining_retries: Optional[int] = None,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
return await self._request(
cast_to=cast_to,
options=options,
Expand All @@ -1029,7 +1092,7 @@ async def _request(
*,
stream: bool,
remaining_retries: int | None,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)

Expand Down Expand Up @@ -1064,7 +1127,7 @@ async def _request(
raise APIConnectionError(request=request) from err

if stream:
return self._process_stream_response(cast_to=cast_to, response=response)
return AsyncStream(cast_to=cast_to, response=response, client=self)

try:
rsp = self._process_response(cast_to=cast_to, options=options, response=response)
Expand All @@ -1081,7 +1144,7 @@ async def _retry_request(
response_headers: Optional[httpx.Headers] = None,
*,
stream: bool,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
remaining = remaining_retries - 1
timeout = self._calculate_retry_timeout(remaining, options, response_headers)

Expand All @@ -1094,24 +1157,6 @@ async def _retry_request(
stream=stream,
)

async def _process_stream_response(
self,
*,
cast_to: Type[ResponseT],
response: httpx.Response,
) -> AsyncIterator[ResponseT]:
async for raw_line in response.aiter_lines():
if not raw_line or raw_line == "\n":
continue

try:
line = self._process_stream_line(raw_line)
except StopStreaming:
# we are done!
break

yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response)

def _request_api_list(
self,
model: Type[ModelT],
Expand Down Expand Up @@ -1153,7 +1198,7 @@ async def post(
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: Literal[True],
) -> AsyncIterator[ResponseT]:
) -> AsyncStream[ResponseT]:
...

@overload
Expand All @@ -1166,7 +1211,7 @@ async def post(
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: bool,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
...

async def post(
Expand All @@ -1178,7 +1223,7 @@ async def post(
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: bool = False,
) -> ResponseT | AsyncIterator[ResponseT]:
) -> ResponseT | AsyncStream[ResponseT]:
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
return await self.request(cast_to, opts, stream=stream)

Expand Down
12 changes: 4 additions & 8 deletions src/lithic/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,10 @@
RequestOptions,
)
from ._version import __version__
from ._base_client import (
DEFAULT_LIMITS,
DEFAULT_TIMEOUT,
DEFAULT_MAX_RETRIES,
SyncAPIClient,
AsyncAPIClient,
make_request_options,
)
from ._base_client import DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
from ._base_client import Stream as Stream
from ._base_client import AsyncStream as AsyncStream
from ._base_client import SyncAPIClient, AsyncAPIClient, make_request_options

__all__ = [
"ENVIRONMENTS",
Expand Down

0 comments on commit e4da92e

Please sign in to comment.