Skip to content

Commit

Permalink
Add local async persisters and tests
Browse files Browse the repository at this point in the history
Prototyping and testing async persisters:
- Add AsyncDevNull and AsyncInMemory persisters for tests
- Added support for async sqlite persister
- Test Async persister interface, async builder, async application
  • Loading branch information
jernejfrank committed Dec 30, 2024
1 parent 8389da9 commit 40fdede
Show file tree
Hide file tree
Showing 5 changed files with 804 additions and 7 deletions.
317 changes: 317 additions & 0 deletions burr/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections import defaultdict
from typing import Any, Dict, Literal, Optional, TypedDict

import aiosqlite

from burr.common.types import BaseCopyable
from burr.core import Action
from burr.core.state import State, logger
Expand Down Expand Up @@ -261,6 +263,30 @@ def save(
return


class AsyncDevNullPersister(AsyncBaseStatePersister):
"""Does nothing asynchronously, do not use this. This is for testing only."""

async def load(
self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
) -> Optional[PersistedStateData]:
return None

async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
return []

async def save(
self,
partition_key: Optional[str],
app_id: str,
sequence_id: int,
position: str,
state: State,
status: Literal["completed", "failed"],
**kwargs,
):
return


class SQLitePersister(BaseStatePersister, BaseCopyable):
"""Class for SQLite persistence of state. This is a simple implementation."""

Expand Down Expand Up @@ -476,6 +502,243 @@ def __setstate__(self, state):
)


class AsyncSQLitePersister(AsyncBaseStatePersister, BaseCopyable):
"""Class for asynchronous SQLite persistence of state. This is a simple implementation.
SQLite is specifically single-threaded and `aiosqlite <https://aiosqlite.omnilib.dev/en/latest/index.html>`_
creates async support through multi-threading. This persister is mainly here for quick prototyping and testing;
we suggest to consider a different database with native async support for production.
Note the third-party library `aiosqlite <https://aiosqlite.omnilib.dev/en/latest/index.html>`_,
is maintained and considered stable considered stable: https://github.com/omnilib/aiosqlite/issues/309.
"""

def copy(self) -> "Self":
return AsyncSQLitePersister(
db_path=self.db_path,
table_name=self.table_name,
serde_kwargs=self.serde_kwargs,
connect_kwargs=self._connect_kwargs,
)

PARTITION_KEY_DEFAULT = ""

@classmethod
async def from_values(
cls,
db_path: str,
table_name: str = "burr_state",
serde_kwargs: dict = None,
connect_kwargs: dict = None,
) -> "AsyncSQLitePersister":
"""Creates a new instance of the AsyncSQLitePersister from passed in values.
:param db_path: the path the DB will be stored.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
:param connect_kwargs: kwargs to pass to the aiosqlite.connect method.
:return: async sqlite persister instance with an open connection. You are responsible
for closing the connection yourself.
"""
connection = await aiosqlite.connect(
db_path, **connect_kwargs if connect_kwargs is not None else {}
)
return cls(connection, table_name, serde_kwargs)

def __init__(
self,
connection,
table_name: str = "burr_state",
serde_kwargs: dict = None,
):
"""Constructor.
NOTE: you are responsible to handle closing of the connection / teardown manually. To help,
we provide a close() method.
:param connection: the path the DB will be stored.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
"""
self.connection = connection
self.table_name = table_name
self.serde_kwargs = serde_kwargs or {}
self._initialized = False

async def create_table_if_not_exists(self, table_name: str):
"""Helper function to create the table where things are stored if it doesn't exist."""
cursor = await self.connection.cursor()
await cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name} (
partition_key TEXT DEFAULT '{AsyncSQLitePersister.PARTITION_KEY_DEFAULT}',
app_id TEXT NOT NULL,
sequence_id INTEGER NOT NULL,
position TEXT NOT NULL,
status TEXT NOT NULL,
state TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (partition_key, app_id, sequence_id, position)
)"""
)
await cursor.execute(
f"""
CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON {table_name} (created_at);
"""
)
await self.connection.commit()

async def initialize(self):
"""Asynchronously creates the table if it doesn't exist"""
# Usage
await self.create_table_if_not_exists(self.table_name)
self._initialized = True

async def is_initialized(self) -> bool:
"""This checks to see if the table has been created in the database or not.
It defaults to using the initialized field, else queries the database to see if the table exists.
It then sets the initialized field to True if the table exists.
"""
if self._initialized:
return True

cursor = await self.connection.cursor()
await cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (self.table_name,)
)
self._initialized = await cursor.fetchone() is not None
return self._initialized

async def list_app_ids(self, partition_key: Optional[str] = None, **kwargs) -> list[str]:
partition_key = (
partition_key
if partition_key is not None
else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
)

cursor = await self.connection.cursor()
await cursor.execute(
f"SELECT DISTINCT app_id FROM {self.table_name} "
f"WHERE partition_key = ? "
f"ORDER BY created_at DESC",
(partition_key,),
)
app_ids = [row[0] for row in await cursor.fetchall()]
return app_ids

async def load(
self,
partition_key: Optional[str],
app_id: Optional[str],
sequence_id: Optional[int] = None,
**kwargs,
) -> Optional[PersistedStateData]:
"""Asynchronously loads state for a given partition id.
Depending on the parameters, this will return the last thing written, the last thing written for a given app_id,
or a specific sequence_id for a given app_id.
:param partition_key:
:param app_id:
:param sequence_id:
:return:
"""
partition_key = (
partition_key
if partition_key is not None
else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
)
logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id)
cursor = await self.connection.cursor()
if app_id is None:
# get latest for all app_ids
await cursor.execute(
f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
f"WHERE partition_key = ? "
f"ORDER BY CREATED_AT DESC LIMIT 1",
(partition_key,),
)
elif sequence_id is None:
await cursor.execute(
f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
f"WHERE partition_key = ? AND app_id = ? "
f"ORDER BY sequence_id DESC LIMIT 1",
(partition_key, app_id),
)
else:
await cursor.execute(
f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
f"WHERE partition_key = ? AND app_id = ? AND sequence_id = ?",
(partition_key, app_id, sequence_id),
)
row = await cursor.fetchone()
if row is None:
return None
_state = State.deserialize(json.loads(row[1]), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": row[3],
"sequence_id": row[2],
"position": row[0],
"state": _state,
"created_at": row[4],
"status": row[5],
}

async def save(
self,
partition_key: Optional[str],
app_id: str,
sequence_id: int,
position: str,
state: State,
status: Literal["completed", "failed"],
**kwargs,
):
"""
Asynchronously saves the state for a given app_id, sequence_id, and position.
This method connects to the SQLite database, converts the state to a JSON string, and inserts a new record
into the table with the provided partition_key, app_id, sequence_id, position, and state. After the operation,
it commits the changes and closes the connection to the database.
:param partition_key: The partition key. This could be None, but it's up to the persister to whether
that is a valid value it can handle.
:param app_id: The identifier for the app instance being recorded.
:param sequence_id: The state corresponding to a specific point in time.
:param position: The position in the sequence of states.
:param state: The state to be saved, an instance of the State class.
:param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was
before the action was applied.
:return: None
"""
logger.debug(
"saving %s, %s, %s, %s, %s, %s",
partition_key,
app_id,
sequence_id,
position,
state,
status,
)
partition_key = (
partition_key
if partition_key is not None
else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
)
cursor = await self.connection.cursor()
json_state = json.dumps(state.serialize(**self.serde_kwargs))
await cursor.execute(
f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
f"VALUES (?, ?, ?, ?, ?, ?)",
(partition_key, app_id, sequence_id, position, json_state, status),
)
await self.connection.commit()

async def close(self):
await self.connection.close()


class InMemoryPersister(BaseStatePersister):
"""In-memory persister for testing purposes. This is not recommended for production use."""

Expand Down Expand Up @@ -529,7 +792,61 @@ def save(
self._storage[partition_key][app_id].append(persisted_state)


class AsyncInMemoryPersister(AsyncBaseStatePersister):
"""Sync in-memory persister for testing purposes. This is not recommended for production use."""

def __init__(self):
self._storage = defaultdict(lambda: defaultdict(list))

async def load(
self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
) -> Optional[PersistedStateData]:
# If no app_id provided, return None
if app_id is None:
return None

if not (states := self._storage[partition_key][app_id]):
return None

if sequence_id is None:
return states[-1]

# Find states matching the specific sequence_id
matching_states = [state for state in states if state["sequence_id"] == sequence_id]

# Return the latest state for this sequence_id, if exists
return matching_states[-1] if matching_states else None

async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
return list(self._storage[partition_key].keys())

async def save(
self,
partition_key: Optional[str],
app_id: str,
sequence_id: int,
position: str,
state: State,
status: Literal["completed", "failed"],
**kwargs,
):
# Create a PersistedStateData entry
persisted_state: PersistedStateData = {
"partition_key": partition_key or "",
"app_id": app_id,
"sequence_id": sequence_id,
"position": position,
"state": state,
"created_at": datetime.datetime.now().isoformat(),
"status": status,
}

# Store the state
self._storage[partition_key][app_id].append(persisted_state)


SQLLitePersister = SQLitePersister
AsyncSQLLitePersister = AsyncSQLitePersister

if __name__ == "__main__":
s = SQLitePersister(db_path=".SQLite.db", table_name="test1")
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ streamlit = [
"streamlit",
"graphviz",
"matplotlib",
"sf-hamilton"
"sf-hamilton",
]

hamilton = [
Expand All @@ -40,6 +40,10 @@ graphviz = [
"graphviz"
]

sqlite = [
"aiosqlite"
]

postgresql = [
"psycopg2-binary"
]
Expand All @@ -61,6 +65,7 @@ tests = [
"pydantic[email]",
"pyarrow",
"redis",
"aiosqlite",
"burr[opentelemetry]",
"burr[haystack]",
"burr[ray]"
Expand All @@ -77,6 +82,7 @@ documentation = [
"psycopg2-binary",
"redis",
"ray",
"aiosqlite",
"sphinxcontrib-googleanalytics"
]

Expand Down
Loading

0 comments on commit 40fdede

Please sign in to comment.