Skip to content

Commit

Permalink
Async lock introduced
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 8, 2024
1 parent 20b6b5f commit 2b6eebf
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 18 deletions.
21 changes: 15 additions & 6 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
"""

from abc import ABC, abstractmethod
from asyncio import Lock
from importlib import import_module
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field, field_validator, validate_call
from .protocol import PROTOCOLS
Expand All @@ -35,11 +36,6 @@ class DBContextStorage(ABC):
_responses_field_name: Literal["responses"] = "responses"
_default_subscript_value: int = 3

@property
@abstractmethod
def is_asynchronous(self) -> bool:
raise NotImplementedError()

def __init__(
self,
path: str,
Expand All @@ -55,10 +51,23 @@ def __init__(
self.rewrite_existing = rewrite_existing
"""Whether to rewrite existing data in the storage."""
self._subscripts = dict()
self._sync_lock = Lock()
for field in (self._labels_field_name, self._requests_field_name, self._responses_field_name):
value = configuration.get(field, self._default_subscript_value)
self._subscripts[field] = 0 if value == "__none__" else value

@staticmethod
def _synchronously_lock(method: Callable):
    """
    Decorator that serializes coroutine-method calls through ``self._sync_lock``.

    Supports two usage forms:

    - ``@DBContextStorage._synchronously_lock`` applied directly to an async
      method: the lock is taken on every call.
    - ``@DBContextStorage._synchronously_lock(condition)`` where ``condition``
      is a predicate over the storage instance: the lock is taken only when the
      predicate returns True (e.g. only for backends that are not natively
      asynchronous).

    NOTE(review): the previous implementation had the arguments inverted — bare
    decoration returned the inner factory instead of a wrapper, and the
    parameterized form treated the predicate as the coroutine (and vice versa),
    so the wrapped method was never actually invoked. The two forms are now
    disambiguated with ``iscoroutinefunction``.

    :param method: Either the coroutine method to wrap (bare decoration) or a
        predicate ``Callable[[DBContextStorage], bool]`` selecting when to lock.
    :return: The locked wrapper (bare form) or a decorator producing it.
    """
    # Local imports keep the helper self-contained at the class-body level.
    from functools import wraps
    from inspect import iscoroutinefunction

    def make_wrapper(coro, condition):
        # Preserve the wrapped coroutine's name and docstring for introspection.
        @wraps(coro)
        async def lock(self: "DBContextStorage", *args, **kwargs):
            if condition(self):
                async with self._sync_lock:
                    return await coro(self, *args, **kwargs)
            else:
                return await coro(self, *args, **kwargs)
        return lock

    if iscoroutinefunction(method):
        # Bare decoration: always acquire the lock.
        return make_wrapper(method, lambda _: True)
    # Called with a predicate: acquire the lock only when the predicate holds.
    return lambda coro: make_wrapper(coro, method)

@staticmethod
def _verify_field_name(method: Callable):
def verifier(self: "DBContextStorage", *args, **kwargs):
Expand Down
8 changes: 6 additions & 2 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class FileContextStorage(DBContextStorage, ABC):
:param serializer: Serializer that will be used for serializing contexts.
"""

is_asynchronous = False

def __init__(
self,
path: str = "",
Expand Down Expand Up @@ -134,12 +132,14 @@ async def clear_all(self) -> None:


class JSONContextStorage(FileContextStorage):
@DBContextStorage._synchronously_lock
async def _save(self, data: SerializableStorage) -> None:
    """Serialize the storage to JSON and write it to ``self.path``.

    When the target file is absent or empty, the parent directory tree is
    created first so the write cannot fail on a missing folder.
    """
    target_missing = not await isfile(self.path)
    if target_missing or (await stat(self.path)).st_size == 0:
        await makedirs(self.path.parent, exist_ok=True)
    serialized = data.model_dump_json()
    async with open(self.path, "w", encoding="utf-8") as file_stream:
        await file_stream.write(serialized)

@DBContextStorage._synchronously_lock
async def _load(self) -> SerializableStorage:
if not await isfile(self.path) or (await stat(self.path)).st_size == 0:
storage = SerializableStorage()
Expand All @@ -151,12 +151,14 @@ async def _load(self) -> SerializableStorage:


class PickleContextStorage(FileContextStorage):
@DBContextStorage._synchronously_lock
async def _save(self, data: SerializableStorage) -> None:
    """Pickle the storage contents and write them to ``self.path``.

    Ensures the parent directory exists whenever the target file is missing
    or has zero size, then writes the pickled payload in binary mode.
    """
    file_absent = not await isfile(self.path)
    if file_absent or (await stat(self.path)).st_size == 0:
        await makedirs(self.path.parent, exist_ok=True)
    payload = dumps(data.model_dump())
    async with open(self.path, "wb") as file_stream:
        await file_stream.write(payload)

@DBContextStorage._synchronously_lock
async def _load(self) -> SerializableStorage:
if not await isfile(self.path) or (await stat(self.path)).st_size == 0:
storage = SerializableStorage()
Expand All @@ -179,9 +181,11 @@ def __init__(
self._storage = None
FileContextStorage.__init__(self, path, rewrite_existing, configuration)

@DBContextStorage._synchronously_lock
async def _save(self, data: SerializableStorage) -> None:
    """Store a plain-dict dump of the storage under the shelve root key."""
    snapshot = data.model_dump()
    self._storage[self._SHELVE_ROOT] = snapshot

@DBContextStorage._synchronously_lock
async def _load(self) -> SerializableStorage:
if self._storage is None:
content = SerializableStorage()
Expand Down
3 changes: 0 additions & 3 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class MemoryContextStorage(DBContextStorage):
- `misc`: [context_id, turn_number, misc]
"""

is_asynchronous = True

def __init__(
self,
path: str = "",
Expand Down Expand Up @@ -46,7 +44,6 @@ async def delete_context(self, ctx_id: str) -> None:
@DBContextStorage._verify_field_name
async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
select = sorted([k for k, v in self._aux_storage[field_name].get(ctx_id, dict()).items() if v is not None], reverse=True)
print("SUBS:", self._subscripts[field_name])
if isinstance(self._subscripts[field_name], int):
select = select[:self._subscripts[field_name]]
elif isinstance(self._subscripts[field_name], Set):
Expand Down
2 changes: 0 additions & 2 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ class MongoContextStorage(DBContextStorage):
_UNIQUE_KEYS = "unique_keys"
_ID_FIELD = "_id"

is_asynchronous = True

def __init__(
self,
path: str,
Expand Down
2 changes: 0 additions & 2 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ class RedisContextStorage(DBContextStorage):
:param key_prefix: "namespace" prefix for all keys, should be set for efficient clearing of all data.
"""

is_asynchronous = True

def __init__(
self,
path: str,
Expand Down
10 changes: 9 additions & 1 deletion chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
@property
def is_asynchronous(self) -> bool:
    """Whether the backend supports concurrent access: every SQL dialect except SQLite."""
    uses_sqlite = self.dialect == "sqlite"
    return not uses_sqlite

async def _create_self_tables(self):
"""
Create tables required for context storing, if they do not exist yet.
Expand Down Expand Up @@ -222,6 +222,7 @@ def _check_availability(self):
install_suggestion = get_protocol_install_suggestion("sqlite")
raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion)

@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
logger.debug(f"Loading main info for {ctx_id}...")
stmt = select(self.main_table).where(self.main_table.c[self._id_column_name] == ctx_id)
Expand All @@ -230,6 +231,7 @@ async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, byt
logger.debug(f"Main info loaded for {ctx_id}")
return None if result is None else result[1:]

@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at: int, misc: bytes, fw_data: bytes) -> None:
logger.debug(f"Updating main info for {ctx_id}...")
insert_stmt = self._INSERT_CALLABLE(self.main_table).values(
Expand All @@ -253,6 +255,7 @@ async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at:
logger.debug(f"Main info updated for {ctx_id}")

# TODO: use foreign keys instead maybe?
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def delete_context(self, ctx_id: str) -> None:
logger.debug(f"Deleting context {ctx_id}...")
async with self.engine.begin() as conn:
Expand All @@ -263,6 +266,7 @@ async def delete_context(self, ctx_id: str) -> None:
logger.debug(f"Context {ctx_id} deleted")

@DBContextStorage._verify_field_name
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
logger.debug(f"Loading latest items for {ctx_id}, {field_name}...")
stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name])
Expand All @@ -278,6 +282,7 @@ async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[in
return result

@DBContextStorage._verify_field_name
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
logger.debug(f"Loading field keys for {ctx_id}, {field_name}...")
stmt = select(self.turns_table.c[self._key_column_name]).where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None))
Expand All @@ -287,6 +292,7 @@ async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
return result

@DBContextStorage._verify_field_name
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[bytes]:
logger.debug(f"Loading field items for {ctx_id}, {field_name} ({collapse_num_list(keys)})...")
stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name])
Expand All @@ -297,6 +303,7 @@ async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int])
return result

@DBContextStorage._verify_field_name
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, bytes]]) -> None:
logger.debug(f"Updating fields for {ctx_id}, {field_name}: {collapse_num_list(list(k for k, _ in items))}...")
if len(items) == 0:
Expand All @@ -320,6 +327,7 @@ async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tup
await conn.execute(update_stmt)
logger.debug(f"Fields updated for {ctx_id}, {field_name}")

@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def clear_all(self) -> None:
logger.debug("Clearing all")
async with self.engine.begin() as conn:
Expand Down
2 changes: 0 additions & 2 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ class YDBContextStorage(DBContextStorage):
_LIMIT_VAR = "limit"
_KEY_VAR = "key"

is_asynchronous = True

def __init__(
self,
path: str,
Expand Down

0 comments on commit 2b6eebf

Please sign in to comment.