Skip to content

Commit

Permalink
kafka: Implement sync orchestrator and executor classes
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 11, 2024
1 parent 217b07a commit 4accc19
Show file tree
Hide file tree
Showing 14 changed files with 1,628 additions and 49 deletions.
6 changes: 4 additions & 2 deletions libs/scheduler-kafka/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ start-services:
stop-services:
docker compose -f tests/compose.yml down

TEST_PATH ?= .

test:
make start-services && poetry run pytest; \
make start-services && poetry run pytest $(TEST_PATH); \
EXIT_CODE=$$?; \
make stop-services; \
exit $$EXIT_CODE

test_watch:
make start-services && poetry run ptw .; \
make start-services && poetry run ptw . -- $(TEST_PATH); \
EXIT_CODE=$$?; \
make stop-services; \
exit $$EXIT_CODE
Expand Down
14 changes: 10 additions & 4 deletions libs/scheduler-kafka/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import asyncio
import logging
import os

from langgraph.scheduler.kafka.orchestrator import KafkaOrchestrator
from langgraph.scheduler.kafka.orchestrator import AsyncKafkaOrchestrator
from langgraph.scheduler.kafka.types import Topics

from your_lib import graph # graph expected to be a compiled LangGraph graph
Expand All @@ -40,7 +40,7 @@ topics = Topics(
)

async def main():
async with KafkaOrchestrator(graph, topics) as orch:
async with AsyncKafkaOrchestrator(graph, topics) as orch:
async for msgs in orch:
logger.info('Procesed %d messages', len(msgs))

Expand All @@ -56,7 +56,7 @@ import asyncio
import logging
import os

from langgraph.scheduler.kafka.executor import KafkaExecutor
from langgraph.scheduler.kafka.executor import AsyncKafkaExecutor
from langgraph.scheduler.kafka.types import Topics

from your_lib import graph # graph expected to be a compiled LangGraph graph
Expand All @@ -70,7 +70,7 @@ topics = Topics(
)

async def main():
async with KafkaExecutor(graph, topics) as orch:
async with AsyncKafkaExecutor(graph, topics) as orch:
async for msgs in orch:
logger.info('Procesed %d messages', len(msgs))

Expand All @@ -89,6 +89,8 @@ python executor.py &

## Configuration

We offer sync and async versions of the orchestrator and executor, `KafkaOrchestrator` and `AsyncKafkaOrchestrator`, and `KafkaExecutor` and `AsyncKafkaExecutor` respectively. The async versions are recommended, especially if you want to process tasks in batches. With the async classes we recommend using `uvloop` for better performance.

You can pass any of the following values as `kwargs` to either `KafkaOrchestrator` or `KafkaExecutor` to configure the consumer:

- batch_max_n (int): Maximum number of messages to include in a single batch. Default: 10.
Expand Down Expand Up @@ -131,3 +133,7 @@ By default the orchestrator and executor will attempt to connect to a Kafka brok
- connections_max_idle_ms (int): Close idle connections after the number
of milliseconds specified by this config. Specifying `None` will
disable idle checks. Default: 540000 (9 minutes).

### Custom consumer/producer

Both the orchestrator and executor accept a `consumer` and `producer` argument, which should implement the `Consumer` or `Producer` protocols respectively. We expect the consumer to have auto-commit disabled, and the producer and consumer to have no serializers/deserializers set.
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import dataclasses
from typing import Any, Sequence

import aiokafka


class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
async def getmany(
self, timeout_ms: int, max_records: int
) -> dict[str, Sequence[dict[str, Any]]]:
batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
return {t: [dataclasses.asdict(m) for m in msgs] for t, msgs in batch.items()}
pass


class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
Expand Down
39 changes: 39 additions & 0 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/default_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import concurrent.futures
from typing import Optional, Sequence

from kafka import KafkaConsumer, KafkaProducer
from langgraph.scheduler.kafka.types import ConsumerRecord, TopicPartition


class DefaultConsumer(KafkaConsumer):
def getmany(
self, timeout_ms: int, max_records: int
) -> dict[TopicPartition, Sequence[ConsumerRecord]]:
return self.poll(timeout_ms=timeout_ms, max_records=max_records)

def __enter__(self):
return self

def __exit__(self, *args):
self.close()


class DefaultProducer(KafkaProducer):
def send(
self,
topic: str,
*,
key: Optional[bytes] = None,
value: Optional[bytes] = None,
) -> concurrent.futures.Future:
fut = concurrent.futures.Future()
kfut = super().send(topic, key=key, value=value)
kfut.add_callback(fut.set_result)
kfut.add_errback(fut.set_exception)
return fut

def __enter__(self):
return self

def __exit__(self, *args):
self.close()
218 changes: 211 additions & 7 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
import concurrent.futures
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
AsyncExitStack,
ExitStack,
)
from functools import partial
from typing import Any, Optional, Sequence

Expand All @@ -12,23 +18,29 @@
from langgraph.errors import CheckpointNotLatest, GraphDelegate, TaskNotFound
from langgraph.pregel import Pregel
from langgraph.pregel.algo import prepare_single_task
from langgraph.pregel.executor import AsyncBackgroundExecutor, Submit
from langgraph.pregel.manager import AsyncChannelsManager
from langgraph.pregel.executor import (
AsyncBackgroundExecutor,
BackgroundExecutor,
Submit,
)
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.runner import PregelRunner
from langgraph.pregel.types import RetryPolicy
from langgraph.scheduler.kafka.retry import aretry
from langgraph.scheduler.kafka.retry import aretry, retry
from langgraph.scheduler.kafka.types import (
AsyncConsumer,
AsyncProducer,
Consumer,
ErrorMessage,
MessageToExecutor,
MessageToOrchestrator,
Producer,
Topics,
)
from langgraph.utils.config import patch_configurable


class KafkaExecutor(AbstractAsyncContextManager):
class AsyncKafkaExecutor(AbstractAsyncContextManager):
consumer: AsyncConsumer

producer: AsyncProducer
Expand Down Expand Up @@ -93,7 +105,7 @@ async def __anext__(self) -> Sequence[MessageToExecutor]:
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
serde.loads(msg["value"]) for msgs in recs.values() for msg in msgs
serde.loads(msg.value) for msgs in recs.values() for msg in msgs
]
# process batch
await asyncio.gather(*(self.each(msg) for msg in msgs))
Expand Down Expand Up @@ -189,7 +201,7 @@ async def attempt(self, msg: MessageToExecutor) -> None:
pass
else:
# task was not found
await self.graph.checkpointer.put_writes(
await self.graph.checkpointer.aput_writes(
msg["config"], [(ERROR, TaskNotFound())]
)
# notify orchestrator
Expand Down Expand Up @@ -220,3 +232,195 @@ def _put_writes(
writes: list[tuple[str, Any]],
) -> None:
return submit(self.graph.checkpointer.aput_writes, config, writes, task_id)


class KafkaExecutor(AbstractContextManager):
consumer: Consumer

producer: Producer

def __init__(
self,
graph: Pregel,
topics: Topics,
*,
batch_max_n: int = 10,
batch_max_ms: int = 1000,
retry_policy: Optional[RetryPolicy] = None,
consumer: Optional[Consumer] = None,
producer: Optional[Producer] = None,
**kwargs: Any,
) -> None:
self.graph = graph
self.topics = topics
self.stack = ExitStack()
self.kwargs = kwargs
self.consumer = consumer
self.producer = producer
self.batch_max_n = batch_max_n
self.batch_max_ms = batch_max_ms
self.retry_policy = retry_policy

def __enter__(self) -> Self:
self.subgraphs = dict(self.graph.get_subgraphs(recurse=True))
self.submit = self.stack.enter_context(BackgroundExecutor({}))
if self.consumer is None:
from langgraph.scheduler.kafka.default_sync import DefaultConsumer

self.consumer = self.stack.enter_context(
DefaultConsumer(
self.topics.executor,
auto_offset_reset="earliest",
group_id="executor",
enable_auto_commit=False,
**self.kwargs,
)
)
if self.producer is None:
from langgraph.scheduler.kafka.default_sync import DefaultProducer

self.producer = self.stack.enter_context(
DefaultProducer(
**self.kwargs,
)
)
return self

def __exit__(self, *args: Any) -> None:
return self.stack.__exit__(*args)

def __iter__(self) -> Self:
return self

def __next__(self) -> Sequence[MessageToExecutor]:
# wait for next batch
recs = self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
serde.loads(msg.value) for msgs in recs.values() for msg in msgs
]
# process batch
concurrent.futures.wait(self.submit(self.each, msg) for msg in msgs)
# commit offsets
self.consumer.commit()
# return message
return msgs

def each(self, msg: MessageToExecutor) -> None:
try:
retry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except GraphDelegate as exc:
for arg in exc.args:
fut = self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
config=arg["config"],
input=orjson.Fragment(
self.graph.checkpointer.serde.dumps(arg["input"])
),
finally_executor=[msg],
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
arg["config"]["configurable"]["thread_id"],
arg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
fut.result()
except Exception as exc:
fut = self.producer.send(
self.topics.error,
value=serde.dumps(
ErrorMessage(
topic=self.topics.executor,
msg=msg,
error=repr(exc),
)
),
)
fut.result()

def attempt(self, msg: MessageToExecutor) -> None:
# find graph
if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"):
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
if recast_checkpoint_ns in self.subgraphs:
graph = self.subgraphs[recast_checkpoint_ns]
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
else:
graph = self.graph
# process message
saved = self.graph.checkpointer.get_tuple(
patch_configurable(msg["config"], {"checkpoint_id": None})
)
if saved is None:
raise RuntimeError("Checkpoint not found")
if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]:
raise CheckpointNotLatest()
with ChannelsManager(
graph.channels, saved.checkpoint, msg["config"], self.graph.store
) as (channels, managed), BackgroundExecutor({}) as submit:
if task := prepare_single_task(
msg["task"]["path"],
msg["task"]["id"],
checkpoint=saved.checkpoint,
processes=graph.nodes,
channels=channels,
managed=managed,
config=patch_configurable(msg["config"], {CONFIG_KEY_DELEGATE: True}),
step=saved.metadata["step"] + 1,
for_execution=True,
checkpointer=self.graph.checkpointer,
):
# execute task, saving writes
runner = PregelRunner(
submit=submit,
put_writes=partial(self._put_writes, submit, msg["config"]),
)
for _ in runner.tick([task], reraise=False):
pass
else:
# task was not found
self.graph.checkpointer.put_writes(
msg["config"], [(ERROR, TaskNotFound())]
)
# notify orchestrator
fut = self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_executor=msg.get("finally_executor"),
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
msg["config"]["configurable"]["thread_id"],
msg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
fut.result()

def _put_writes(
self,
submit: Submit,
config: RunnableConfig,
task_id: str,
writes: list[tuple[str, Any]],
) -> None:
return submit(self.graph.checkpointer.put_writes, config, writes, task_id)
Loading

0 comments on commit 4accc19

Please sign in to comment.