From e4da92e38d6a79c92000762a4cbf731d9f32e980 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Fri, 17 Mar 2023 15:58:53 -0700 Subject: [PATCH] refactor(internal): move to `Stream` and `AsyncStream` classes for streaming refactor(internal): move to `Stream` and `AsyncStream` classes for streaming --- src/lithic/__init__.py | 4 + src/lithic/_base_client.py | 155 ++++++++++++++++++++++++------------- src/lithic/_client.py | 12 +-- 3 files changed, 108 insertions(+), 63 deletions(-) diff --git a/src/lithic/__init__.py b/src/lithic/__init__.py index 852c5f02..91d59b00 100644 --- a/src/lithic/__init__.py +++ b/src/lithic/__init__.py @@ -7,10 +7,12 @@ ENVIRONMENTS, Client, Lithic, + Stream, Timeout, Transport, AsyncClient, AsyncLithic, + AsyncStream, ProxiesTypes, RequestOptions, ) @@ -55,6 +57,8 @@ "RequestOptions", "Client", "AsyncClient", + "Stream", + "AsyncStream", "Lithic", "AsyncLithic", "ENVIRONMENTS", diff --git a/src/lithic/_base_client.py b/src/lithic/_base_client.py index 7cee090a..5a40cbb3 100644 --- a/src/lithic/_base_client.py +++ b/src/lithic/_base_client.py @@ -99,6 +99,88 @@ class StopStreaming(Exception): """Raised internally when processing of a streamed response should be stopped.""" +class Stream(Generic[ResponseT]): + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[ResponseT], + response: httpx.Response, + client: SyncAPIClient, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._iterator = self.__iter() + + def __next__(self) -> ResponseT: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[ResponseT]: + for item in self._iterator: + yield item + + def __iter(self) -> Iterator[ResponseT]: + cast_to = self._cast_to + response = self.response + process_line = self._client._process_stream_line + process_data = self._client._process_response_data + + for raw_line in response.iter_lines(): + if not raw_line or raw_line == "\n": + continue + + try: + line = process_line(raw_line) + except StopStreaming: + # we are done! + break + + yield process_data(data=json.loads(line), cast_to=cast_to, response=response) + + +class AsyncStream(Generic[ResponseT]): + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[ResponseT], + response: httpx.Response, + client: AsyncAPIClient, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._iterator = self.__iter() + + async def __anext__(self) -> ResponseT: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[ResponseT]: + async for item in self._iterator: + yield item + + async def __iter(self) -> AsyncIterator[ResponseT]: + cast_to = self._cast_to + response = self.response + process_line = self._client._process_stream_line + process_data = self._client._process_response_data + + async for raw_line in response.aiter_lines(): + if not raw_line or raw_line == "\n": + continue + + try: + line = process_line(raw_line) + except StopStreaming: + # we are done! + break + + yield process_data(data=json.loads(line), cast_to=cast_to, response=response) + + class PageInfo: """Stores the necesary information to build the request to retrieve the next page. @@ -526,7 +608,6 @@ def _process_response_data( return cast(ResponseT, construct_type(type_=cast_to, value=data)) - # TODO: make the constants in here configurable def _process_stream_line(self, contents: str) -> str: """Pre-process an indiviudal line from a streaming response""" if contents == "data: [DONE]\n": @@ -690,7 +771,7 @@ def request( remaining_retries: Optional[int] = None, *, stream: Literal[True], - ) -> Iterator[ResponseT]: + ) -> Stream[ResponseT]: ... @overload @@ -712,7 +793,7 @@ def request( remaining_retries: Optional[int] = None, *, stream: bool = False, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: ... def request( @@ -722,7 +803,7 @@ def request( remaining_retries: Optional[int] = None, *, stream: bool = False, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: return self._request( cast_to=cast_to, options=options, @@ -737,7 +818,7 @@ def _request( options: FinalRequestOptions, remaining_retries: int | None, stream: bool, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: retries = self._remaining_retries(remaining_retries, options) request = self._build_request(options) @@ -762,7 +843,7 @@ def _request( raise APIConnectionError(request=request) from err if stream: - return self._process_stream_response(cast_to=cast_to, response=response) + return Stream(cast_to=cast_to, response=response, client=self) try: rsp = self._process_response(cast_to=cast_to, options=options, response=response) @@ -779,7 +860,7 @@ def _retry_request( response_headers: Optional[httpx.Headers] = None, *, stream: bool, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: remaining = remaining_retries - 1 timeout = self._calculate_retry_timeout(remaining, options, response_headers) @@ -794,24 +875,6 @@ def _retry_request( stream=stream, ) - def _process_stream_response( - self, - *, - cast_to: Type[ResponseT], - response: httpx.Response, - ) -> Iterator[ResponseT]: - for raw_line in response.iter_lines(): - if not raw_line or raw_line == "\n": - continue - - try: - line = self._process_stream_line(raw_line) - except StopStreaming: - # we are done! - break - - yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response) - def _request_api_list( self, model: Type[ModelT], @@ -861,7 +924,7 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[True], - ) -> Iterator[ResponseT]: + ) -> Stream[ResponseT]: ... @overload @@ -874,7 +937,7 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: ... def post( @@ -886,7 +949,7 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool = False, - ) -> ResponseT | Iterator[ResponseT]: + ) -> ResponseT | Stream[ResponseT]: opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options) return cast(ResponseT, self.request(cast_to, opts, stream=stream)) @@ -993,7 +1056,7 @@ async def request( *, stream: Literal[True], remaining_retries: Optional[int] = None, - ) -> AsyncIterator[ResponseT]: + ) -> AsyncStream[ResponseT]: ... @overload @@ -1004,7 +1067,7 @@ async def request( *, stream: bool, remaining_retries: Optional[int] = None, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: ... async def request( @@ -1014,7 +1077,7 @@ async def request( *, stream: bool = False, remaining_retries: Optional[int] = None, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: return await self._request( cast_to=cast_to, options=options, @@ -1029,7 +1092,7 @@ async def _request( *, stream: bool, remaining_retries: int | None, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: retries = self._remaining_retries(remaining_retries, options) request = self._build_request(options) @@ -1064,7 +1127,7 @@ async def _request( raise APIConnectionError(request=request) from err if stream: - return self._process_stream_response(cast_to=cast_to, response=response) + return AsyncStream(cast_to=cast_to, response=response, client=self) try: rsp = self._process_response(cast_to=cast_to, options=options, response=response) @@ -1081,7 +1144,7 @@ async def _retry_request( response_headers: Optional[httpx.Headers] = None, *, stream: bool, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: remaining = remaining_retries - 1 timeout = self._calculate_retry_timeout(remaining, options, response_headers) @@ -1094,24 +1157,6 @@ async def _retry_request( stream=stream, ) - async def _process_stream_response( - self, - *, - cast_to: Type[ResponseT], - response: httpx.Response, - ) -> AsyncIterator[ResponseT]: - async for raw_line in response.aiter_lines(): - if not raw_line or raw_line == "\n": - continue - - try: - line = self._process_stream_line(raw_line) - except StopStreaming: - # we are done! - break - - yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response) - def _request_api_list( self, model: Type[ModelT], @@ -1153,7 +1198,7 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[True], - ) -> AsyncIterator[ResponseT]: + ) -> AsyncStream[ResponseT]: ... @overload @@ -1166,7 +1211,7 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: bool, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: ... async def post( @@ -1178,7 +1223,7 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: bool = False, - ) -> ResponseT | AsyncIterator[ResponseT]: + ) -> ResponseT | AsyncStream[ResponseT]: opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options) return await self.request(cast_to, opts, stream=stream) diff --git a/src/lithic/_client.py b/src/lithic/_client.py index a4efd3f8..643bdbd6 100644 --- a/src/lithic/_client.py +++ b/src/lithic/_client.py @@ -23,14 +23,10 @@ RequestOptions, ) from ._version import __version__ -from ._base_client import ( - DEFAULT_LIMITS, - DEFAULT_TIMEOUT, - DEFAULT_MAX_RETRIES, - SyncAPIClient, - AsyncAPIClient, - make_request_options, -) +from ._base_client import DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES +from ._base_client import Stream as Stream +from ._base_client import AsyncStream as AsyncStream +from ._base_client import SyncAPIClient, AsyncAPIClient, make_request_options __all__ = [ "ENVIRONMENTS",