Skip to content

Commit

Permalink
key filter implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Oct 21, 2024
1 parent 757fe48 commit 5340256
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
41 changes: 39 additions & 2 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,32 @@

from abc import ABC, abstractmethod
from importlib import import_module
from inspect import signature
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field, field_validator, validate_call
from pydantic import BaseModel, Field

from .protocol import PROTOCOLS

# A single subscript value: the literal "__all__", an integer, or a set of string keys.
# NOTE(review): presumably "everything", "a number of items" and "exactly these keys" -- confirm with the users of this alias.
_SUBSCRIPT_TYPE = Union[Literal["__all__"], int, Set[str]]
# Mapping from a string name to its subscript value; on top of the values above, the literal "__none__" is accepted here.
_SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]]


class ContextIdFilter(BaseModel):
    """
    Filter that narrows down a set of context keys.

    Every criterion is optional: a bound left as ``None`` and an empty whitelist are skipped.
    Criteria that are set are applied one after another, so a key has to satisfy all of them.
    """

    # Exclusive lower bound: only keys comparing strictly greater than it are kept.
    update_time_greater: Optional[int] = Field(default=None)
    # Exclusive upper bound: only keys comparing strictly less than it are kept.
    update_time_less: Optional[int] = Field(default=None)
    # When non-empty, only keys contained in this set are kept.
    origin_interface_whitelist: Set[str] = Field(default_factory=set)

    def filter_keys(self, keys: Set[str]) -> Set[str]:
        """
        Return the subset of `keys` that satisfies every configured criterion.

        :param keys: Keys to filter; the passed set itself is never modified.
        :return: The keys that passed all active criteria (the very same object
            if no criterion is active).
        """
        # NOTE(review): both bounds are integers while `keys` is annotated as a set of strings,
        # so the ordering comparisons below raise TypeError for genuine string keys.
        # Presumably the update time stored for each key was meant to be compared -- confirm.
        if self.update_time_greater is not None:
            keys = {k for k in keys if k > self.update_time_greater}
        if self.update_time_less is not None:
            # Compare against the upper bound; the lower bound was used here by mistake,
            # which ignored `update_time_less` or failed when only the upper bound was set.
            keys = {k for k in keys if k < self.update_time_less}
        if self.origin_interface_whitelist:
            keys = {k for k in keys if k in self.origin_interface_whitelist}
        return keys


class DBContextStorage(ABC):
_main_table_name: Literal["main"] = "main"
_turns_table_name: Literal["turns"] = "turns"
Expand Down Expand Up @@ -72,6 +86,29 @@ def verifier(self, *args, **kwargs):
else:
return method(self, *args, **kwargs)
return verifier

@staticmethod
def _convert_id_filter(method: Callable):
    """
    Decorator that normalizes the ``filter`` argument of the wrapped method.

    The filter may be given as the first positional argument (after ``self``) or as the
    ``filter`` keyword argument. A dictionary is validated into a :py:class:`ContextIdFilter`;
    an existing :py:class:`ContextIdFilter` is passed through unchanged.
    The wrapped method always receives the filter as the ``filter`` keyword argument.

    :param method: Method to wrap.
    :return: The wrapping function.
    :raises ValueError: (from the wrapper) if the filter is missing, ``None``,
        or neither a dictionary nor a :py:class:`ContextIdFilter`.
    """
    from functools import wraps  # local import keeps the decorator self-contained

    @wraps(method)
    def verifier(self, *args, **kwargs):
        if args:
            # `self` is bound separately, so the filter is the FIRST positional argument;
            # it is removed from `args` because it is re-passed by keyword below.
            filter_obj, args = args[0], args[1:]
        else:
            filter_obj = kwargs.pop("filter", None)
        if filter_obj is None:
            raise ValueError(f"For method {method.__name__} argument 'filter' is not found!")
        elif isinstance(filter_obj, dict):
            # `model_validate` is the pydantic classmethod building a model from a dictionary.
            filter_obj = ContextIdFilter.model_validate(filter_obj)
        elif not isinstance(filter_obj, ContextIdFilter):
            raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!")
        return method(self, *args, filter=filter_obj, **kwargs)

    return verifier

@abstractmethod
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Get the ids of the stored contexts that pass the given filter.

    :param filter: Criteria the returned ids have to satisfy, given either as
        a :py:class:`ContextIdFilter` or as a dictionary.
    :return: The matching context ids.
    :raises NotImplementedError: always in this abstract declaration; storages must override it.
    """
    raise NotImplementedError

@abstractmethod
async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
Expand Down
8 changes: 6 additions & 2 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import asyncio
from pickle import loads, dumps
from shelve import DbfilenameShelf
from typing import List, Set, Tuple, Dict, Optional
from typing import Any, List, Set, Tuple, Dict, Optional, Union

from pydantic import BaseModel, Field

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT

try:
from aiofiles import open
Expand Down Expand Up @@ -61,6 +61,10 @@ async def _save(self, data: SerializableStorage) -> None:
async def _load(self) -> SerializableStorage:
    """
    Return the storage contents as a `SerializableStorage`.

    This implementation only raises `NotImplementedError`,
    so subclasses are expected to override it.
    """
    raise NotImplementedError

# The id-filter converter (not the field-name verifier) is required here: it turns
# a dictionary filter into a `ContextIdFilter` before `filter_keys` is called on it.
@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Get the ids of the contexts kept in the file storage that pass the filter.

    :param filter: Filter to apply to the keys of the main table.
    :return: The keys of the main table accepted by the filter.
    """
    storage = await self._load()
    return filter.filter_keys(set(storage.main.keys()))

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
    """
    Look up the main-table entry of a context.

    :param ctx_id: Id of the context to look up.
    :return: The entry stored under `ctx_id`, or `None` if there is no such entry.
    """
    storage = await self._load()
    main_table = storage.main
    return main_table.get(ctx_id, None)

Expand Down
8 changes: 6 additions & 2 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT


class MemoryContextStorage(DBContextStorage):
Expand Down Expand Up @@ -32,6 +32,10 @@ def __init__(
self._responses_field_name: dict(),
}

# The id-filter converter (not the field-name verifier) is required here: it turns
# a dictionary filter into a `ContextIdFilter` before `filter_keys` is called on it.
@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Get the ids of the contexts kept in memory that pass the filter.

    :param filter: Filter to apply to the keys of the main storage.
    :return: The keys of the main storage accepted by the filter.
    """
    stored_ids = set(self._main_storage.keys())
    return filter.filter_keys(stored_ids)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
    """
    Look up the main-storage entry of a context.

    :param ctx_id: Id of the context to look up.
    :return: The entry stored under `ctx_id`, or `None` if there is no such entry.
    """
    main_table = self._main_storage
    return main_table.get(ctx_id, None)

Expand Down
9 changes: 7 additions & 2 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from asyncio import gather
from typing import Callable, List, Dict, Set, Tuple, Optional
from typing import Any, List, Dict, Set, Tuple, Optional, Union

try:
from redis.asyncio import Redis
Expand All @@ -23,7 +23,7 @@
except ImportError:
redis_available = False

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .protocol import get_protocol_install_suggestion


Expand Down Expand Up @@ -76,6 +76,11 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]:
def _bytes_to_keys(keys: List[bytes]) -> List[int]:
    """
    Decode every item as UTF-8 text and parse it as an integer.

    :param keys: Byte strings, each holding the textual form of an integer.
    :return: The parsed integers, in the order of the input.
    """
    decoded = (raw.decode("utf-8") for raw in keys)
    return list(map(int, decoded))

# The id-filter converter (not the field-name verifier) is required here: it turns
# a dictionary filter into a `ContextIdFilter` before `filter_keys` is called on it.
@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Get the ids of the contexts kept in Redis that pass the filter.

    :param filter: Filter to apply to the database keys found under the main key.
    :return: The decoded keys accepted by the filter.
    """
    raw_keys = await self.database.keys(f"{self._main_key}:*")
    # NOTE(review): the decoded names presumably still start with the f"{self._main_key}:" prefix
    # used in the pattern above -- confirm whether it should be stripped to get bare context ids.
    context_ids = {k.decode("utf-8") for k in raw_keys}
    return filter.filter_keys(context_ids)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
cti, ca, ua, msc, fd = await gather(
Expand Down

0 comments on commit 5340256

Please sign in to comment.