Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Add support for async functions in agents (LangStream#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Nov 20, 2023
1 parent 81f66bb commit 72063bf
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import threading
from concurrent.futures import Future
from io import BytesIO
from typing import Union, List, Tuple, Any, Optional, Dict, AsyncIterable
from typing import Union, List, Tuple, Any, Optional, AsyncIterable

import fastavro
import grpc
Expand Down Expand Up @@ -86,40 +86,33 @@ async def get_topic_producer_records(self, request_iterator, context):
async for _ in request_iterator:
yield

async def read(self, requests: AsyncIterable[SourceRequest], _):
read_records = {}
op_result = []
async def do_read(self, context, read_records):
last_record_id = 0
read_requests_task = asyncio.create_task(
self.handle_read_requests(requests, read_records, op_result)
)
while True:
if len(op_result) > 0:
if op_result[0] is True:
break
raise op_result[0]
records = await asyncio.to_thread(self.agent.read)
if inspect.iscoroutinefunction(self.agent.read):
records = await self.agent.read()
else:
records = await asyncio.to_thread(self.agent.read)
if len(records) > 0:
records = [wrap_in_record(record) for record in records]
grpc_records = []
for record in records:
schemas, grpc_record = self.to_grpc_record(record)
for schema in schemas:
yield SourceResponse(schema=schema)
await context.write(SourceResponse(schema=schema))
grpc_records.append(grpc_record)
for i, record in enumerate(records):
last_record_id += 1
grpc_records[i].record_id = last_record_id
read_records[last_record_id] = record
yield SourceResponse(records=grpc_records)
read_requests_task.cancel()
await context.write(SourceResponse(records=grpc_records))
else:
await asyncio.sleep(0)

async def read(self, requests: AsyncIterable[SourceRequest], context):
read_records = {}
read_requests_task = asyncio.create_task(self.do_read(context, read_records))

async def handle_read_requests(
self,
requests: AsyncIterable[SourceRequest],
read_records: Dict[int, Record],
read_result,
):
try:
async for request in requests:
if len(request.committed_records) > 0:
Expand All @@ -136,9 +129,8 @@ async def handle_read_requests(
record,
RuntimeError(failure.error_message),
)
read_result.append(True)
except Exception as e:
read_result.append(e)
finally:
read_requests_task.cancel()

async def process(self, requests: AsyncIterable[ProcessorRequest], _):
async for request in requests:
Expand All @@ -149,9 +141,13 @@ async def process(self, requests: AsyncIterable[ProcessorRequest], _):
for source_record in request.records:
grpc_result = ProcessorResult(record_id=source_record.record_id)
try:
processed_records = await asyncio.to_thread(
self.agent.process, self.from_grpc_record(source_record)
)
r = self.from_grpc_record(source_record)
if inspect.iscoroutinefunction(self.agent.process):
processed_records = await self.agent.process(r)
else:
processed_records = await asyncio.to_thread(
self.agent.process, r
)
if isinstance(processed_records, Future):
processed_records = await asyncio.wrap_future(
processed_records
Expand All @@ -175,9 +171,11 @@ async def write(self, requests: AsyncIterable[SinkRequest], context):
self.client_schemas[request.schema.schema_id] = schema
if request.HasField("record"):
try:
result = await asyncio.to_thread(
self.agent.write, self.from_grpc_record(request.record)
)
r = self.from_grpc_record(request.record)
if inspect.iscoroutinefunction(self.agent.write):
result = await self.agent.write(r)
else:
result = await asyncio.to_thread(self.agent.write, r)
if isinstance(result, Future):
await asyncio.wrap_future(result)
yield SinkResponse(record_id=request.record.record_id)
Expand Down Expand Up @@ -280,9 +278,16 @@ def call_method_if_exists(klass, method, *args, **kwargs):
return None


async def acall_method_if_exists(klass, method, *args, **kwargs):
async def acall_method_if_exists(klass, method_name, *args, **kwargs):
method = getattr(klass, method_name, None)
if inspect.iscoroutinefunction(method):
defined_positional_parameters_count = len(inspect.signature(method).parameters)
if defined_positional_parameters_count >= len(args):
return await method(*args, **kwargs)
else:
return await method(*args[:defined_positional_parameters_count], **kwargs)
return await asyncio.to_thread(
call_method_if_exists, klass, method, *args, **kwargs
call_method_if_exists, klass, method_name, *args, **kwargs
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ async def test_failing_record():
assert response.results[0].error == "failure"


async def test_future_record():
@pytest.mark.parametrize("klass", ["MyFutureProcessor", "MyAsyncProcessor"])
async def test_future_record(klass):
async with ServerAndStub(
"langstream_grpc.tests.test_grpc_processor.MyFutureProcessor"
f"langstream_grpc.tests.test_grpc_processor.{klass}"
) as server_and_stub:
response: ProcessorResponse
async for response in server_and_stub.stub.process(
Expand Down Expand Up @@ -270,13 +271,17 @@ def process(self, record: Record) -> List[RecordType]:

class MyFutureProcessor(Processor):
def __init__(self):
self.written_records = []
self.executor = ThreadPoolExecutor(max_workers=10)

def process(self, record: Record) -> Future[List[RecordType]]:
return self.executor.submit(lambda r: [r], record)


class MyAsyncProcessor(Processor):
async def process(self, record: Record) -> List[RecordType]:
return [record]


class ProcessorInitOneParameter:
def __init__(self):
self.myparam = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from io import BytesIO

import fastavro
import pytest

from langstream_grpc.api import Record, Sink
from langstream_grpc.proto.agent_pb2 import (
Expand All @@ -30,9 +31,10 @@
from langstream_grpc.tests.server_and_stub import ServerAndStub


async def test_write():
@pytest.mark.parametrize("klass", ["MySink", "MyFutureSink", "MyAsyncSink"])
async def test_write(klass):
async with ServerAndStub(
"langstream_grpc.tests.test_grpc_sink.MySink"
f"langstream_grpc.tests.test_grpc_sink.{klass}"
) as server_and_stub:

async def requests():
Expand Down Expand Up @@ -66,6 +68,7 @@ async def requests():

assert len(responses) == 1
assert responses[0].record_id == 43
assert responses[0].error == ""
assert len(server_and_stub.server.agent.written_records) == 1
assert (
server_and_stub.server.agent.written_records[0].value().value["field"]
Expand Down Expand Up @@ -94,30 +97,6 @@ async def test_write_error():
assert responses[0].error == "test-error"


async def test_write_future():
async with ServerAndStub(
"langstream_grpc.tests.test_grpc_sink.MyFutureSink"
) as server_and_stub:
responses: list[SinkResponse]
responses = [
response
async for response in server_and_stub.stub.write(
[
SinkRequest(
record=GrpcRecord(
record_id=42,
value=Value(string_value="test"),
)
)
]
)
]
assert len(responses) == 1
assert responses[0].record_id == 42
assert len(server_and_stub.server.agent.written_records) == 1
assert server_and_stub.server.agent.written_records[0].value() == "test"


class MySink(Sink):
def __init__(self):
self.written_records = []
Expand All @@ -126,15 +105,20 @@ def write(self, record: Record):
self.written_records.append(record)


class MyErrorSink(Sink):
def write(self, record: Record):
raise RuntimeError("test-error")


class MyFutureSink(Sink):
def __init__(self):
self.written_records = []
self.executor = ThreadPoolExecutor(max_workers=10)

def write(self, record: Record) -> Future[None]:
return self.executor.submit(lambda r: self.written_records.append(r), record)


class MyAsyncSink(MySink):
async def write(self, record: Record):
super().write(record)


class MyErrorSink(Sink):
def write(self, record: Record):
raise RuntimeError("test-error")
Loading

0 comments on commit 72063bf

Please sign in to comment.