Skip to content

Commit

Permalink
codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 29, 2024
1 parent 6fd0e1a commit 47edbda
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 51 deletions.
8 changes: 4 additions & 4 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from logging import getLogger
from pathlib import Path
from time import time_ns
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union

from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, ValidationError, field_serializer, field_validator
from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, field_serializer, field_validator

from chatsky.utils.logging import collapse_num_list

Expand Down Expand Up @@ -64,11 +64,11 @@ def _validate_framework_data(cls, value: Any) -> Dict:
if isinstance(value, bytes) or isinstance(value, str):
value = loads(value)
return value

@field_serializer("misc", when_used="always")
def _serialize_misc(self, misc: Dict[str, Any]) -> bytes:
return self._misc_adaptor.dump_json(misc)

@field_serializer("framework_data", when_used="always")
def serialize_courses_in_order(self, framework_data: FrameworkData) -> bytes:
return framework_data.model_dump_json().encode()
Expand Down
20 changes: 12 additions & 8 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def __init__(
async def _connect(self):
await gather(
self.main_table.create_index(NameConfig._id_column, background=True, unique=True),
self.turns_table.create_index([NameConfig._id_column, NameConfig._key_column], background=True, unique=True),
self.turns_table.create_index(
[NameConfig._id_column, NameConfig._key_column], background=True, unique=True
),
)

async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
Expand All @@ -81,13 +83,15 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
],
)
return (
ContextInfo.model_validate({
"turn_id": result[NameConfig._current_turn_id_column],
"created_at": result[NameConfig._created_at_column],
"updated_at": result[NameConfig._updated_at_column],
"misc": result[NameConfig._misc_column],
"framework_data": result[NameConfig._framework_data_column],
})
ContextInfo.model_validate(
{
"turn_id": result[NameConfig._current_turn_id_column],
"created_at": result[NameConfig._created_at_column],
"updated_at": result[NameConfig._updated_at_column],
"misc": result[NameConfig._misc_column],
"framework_data": result[NameConfig._framework_data_column],
}
)
if result is not None
else None
)
Expand Down
20 changes: 15 additions & 5 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,28 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._misc_column),
self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._framework_data_column),
)
return ContextInfo.model_validate({"turn_id": cti, "created_at": ca, "updated_at": ua, "misc": msc, "framework_data": fd})
return ContextInfo.model_validate(
{"turn_id": cti, "created_at": ca, "updated_at": ua, "misc": msc, "framework_data": fd}
)
else:
return None

async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None:
ctx_info_dump = ctx_info.model_dump(mode="python")
await gather(
self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._current_turn_id_column, str(ctx_info_dump["turn_id"])),
self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._created_at_column, str(ctx_info_dump["created_at"])),
self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._updated_at_column, str(ctx_info_dump["updated_at"])),
self.database.hset(
f"{self._main_key}:{ctx_id}", NameConfig._current_turn_id_column, str(ctx_info_dump["turn_id"])
),
self.database.hset(
f"{self._main_key}:{ctx_id}", NameConfig._created_at_column, str(ctx_info_dump["created_at"])
),
self.database.hset(
f"{self._main_key}:{ctx_id}", NameConfig._updated_at_column, str(ctx_info_dump["updated_at"])
),
self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._misc_column, ctx_info_dump["misc"]),
self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._framework_data_column, ctx_info_dump["framework_data"]),
self.database.hset(
f"{self._main_key}:{ctx_id}", NameConfig._framework_data_column, ctx_info_dump["framework_data"]
),
)

async def _delete_context(self, ctx_id: str) -> None:
Expand Down
22 changes: 16 additions & 6 deletions chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
from __future__ import annotations
import asyncio
from importlib import import_module
from os import getenv
from typing import Callable, Collection, List, Optional, Set, Tuple
import logging

from chatsky.utils.logging import collapse_num_list
from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion

Expand Down Expand Up @@ -214,7 +212,19 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
stmt = select(self.main_table).where(self.main_table.c[NameConfig._id_column] == ctx_id)
async with self.engine.begin() as conn:
result = (await conn.execute(stmt)).fetchone()
return None if result is None else ContextInfo.model_validate({"turn_id": result[1], "created_at": result[2], "updated_at": result[3], "misc": result[4], "framework_data": result[5]})
return (
None
if result is None
else ContextInfo.model_validate(
{
"turn_id": result[1],
"created_at": result[2],
"updated_at": result[3],
"misc": result[4],
"framework_data": result[5],
}
)
)

async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None:
ctx_info_dump = ctx_info.model_dump(mode="python")
Expand Down Expand Up @@ -253,7 +263,7 @@ async def _delete_context(self, ctx_id: str) -> None:
async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
stmt = select(self.turns_table.c[NameConfig._key_column], self.turns_table.c[field_name])
stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id)
stmt = stmt.where(self.turns_table.c[field_name] != None)
stmt = stmt.where(self.turns_table.c[field_name] != None) # noqa: E711
stmt = stmt.order_by(self.turns_table.c[NameConfig._key_column].desc())
if isinstance(self._subscripts[field_name], int):
stmt = stmt.limit(self._subscripts[field_name])
Expand All @@ -265,15 +275,15 @@ async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[i
async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
stmt = select(self.turns_table.c[NameConfig._key_column])
stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id)
stmt = stmt.where(self.turns_table.c[field_name] != None)
stmt = stmt.where(self.turns_table.c[field_name] != None) # noqa: E711
async with self.engine.begin() as conn:
return [k[0] for k in (await conn.execute(stmt)).fetchall()]

async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]:
stmt = select(self.turns_table.c[NameConfig._key_column], self.turns_table.c[field_name])
stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id)
stmt = stmt.where(self.turns_table.c[NameConfig._key_column].in_(tuple(keys)))
stmt = stmt.where(self.turns_table.c[field_name] != None)
stmt = stmt.where(self.turns_table.c[field_name] != None) # noqa: E711
async with self.engine.begin() as conn:
return list((await conn.execute(stmt)).fetchall())

Expand Down
16 changes: 9 additions & 7 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,15 @@ async def callee(session: Session) -> Optional[ContextInfo]:
commit_tx=True,
)
return (
ContextInfo.model_validate({
"turn_id": result_sets[0].rows[0][NameConfig._current_turn_id_column],
"created_at": result_sets[0].rows[0][NameConfig._created_at_column],
"updated_at": result_sets[0].rows[0][NameConfig._updated_at_column],
"misc": result_sets[0].rows[0][NameConfig._misc_column],
"framework_data": result_sets[0].rows[0][NameConfig._framework_data_column],
})
ContextInfo.model_validate(
{
"turn_id": result_sets[0].rows[0][NameConfig._current_turn_id_column],
"created_at": result_sets[0].rows[0][NameConfig._created_at_column],
"updated_at": result_sets[0].rows[0][NameConfig._updated_at_column],
"misc": result_sets[0].rows[0][NameConfig._misc_column],
"framework_data": result_sets[0].rows[0][NameConfig._framework_data_column],
}
)
if len(result_sets[0].rows) > 0
else None
)
Expand Down
11 changes: 10 additions & 1 deletion chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,16 @@ async def store(self) -> None:
logger.debug(f"Storing context: {self.id}...")
self._updated_at = time_ns()
await gather(
self._storage.update_main_info(self.id, ContextInfo(turn_id=self.current_turn_id, created_at=self._created_at, updated_at=self._updated_at, misc=self.misc, framework_data=self.framework_data)),
self._storage.update_main_info(
self.id,
ContextInfo(
turn_id=self.current_turn_id,
created_at=self._created_at,
updated_at=self._updated_at,
misc=self.misc,
framework_data=self.framework_data,
),
),
self.labels.store(),
self.requests.store(),
self.responses.store(),
Expand Down
4 changes: 1 addition & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import pytest

from pydantic import TypeAdapter

from chatsky import Pipeline, Context, AbsoluteNodeLabel, Message
from chatsky import Pipeline, Context, AbsoluteNodeLabel


def pytest_report_header(config, start_path):
Expand Down
25 changes: 10 additions & 15 deletions tests/context_storages/test_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
mongo_available,
ydb_available,
)
from chatsky.core.context import FrameworkData
from chatsky.utils.testing.cleanup_db import (
delete_file,
delete_mongo,
Expand Down Expand Up @@ -213,21 +212,13 @@ async def test_update_main_info(self, db: DBContextStorage, add_context):
assert await db.load_main_info("2") == ContextInfo(turn_id=1, created_at=1, updated_at=1)

async def test_wrong_field_name(self, db: DBContextStorage):
with pytest.raises(
ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"
):
with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"):
await db.load_field_latest("1", "non-existent")
with pytest.raises(
ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"
):
with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"):
await db.load_field_keys("1", "non-existent")
with pytest.raises(
ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"
):
with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"):
await db.load_field_items("1", "non-existent", [1, 2])
with pytest.raises(
ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"
):
with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"):
await db.update_field_items("1", "non-existent", [(1, b"2")])

async def test_field_get(self, db: DBContextStorage, add_context):
Expand Down Expand Up @@ -310,9 +301,13 @@ async def db_operations(key: int):
str_key = str(key)
key_misc = {f"{key}": key + 2}
await asyncio.sleep(random.random() / 100)
await db.update_main_info(str_key, ContextInfo(turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc))
await db.update_main_info(
str_key, ContextInfo(turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc)
)
await asyncio.sleep(random.random() / 100)
assert await db.load_main_info(str_key) == ContextInfo(turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc)
assert await db.load_main_info(str_key) == ContextInfo(
turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc
)

for idx in range(1, 20):
await db.update_field_items(str_key, "requests", [(0, bytes(2 * key + idx)), (idx, bytes(key + idx))])
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TestTurns:
@pytest.fixture
def ctx(self, context_factory):
return context_factory()

async def test_complete_turn(self, ctx: Context):
ctx.labels[5] = ("flow", "node5")
ctx.requests[5] = Message(text="text5")
Expand Down
4 changes: 3 additions & 1 deletion tests/utils/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ async def test_get_context(context_storage: JSONContextStorage):
await copy_ctx.requests.update({0: Message(misc={"0": ">e"}), 1: Message(misc={"0": "zv"})})
await copy_ctx.responses.update({0: Message(misc={"0": "3 "}), 1: Message(misc={"0": "sh"})})
copy_ctx.misc.update({"0": " d]", "1": " (b"})
assert context.model_dump(exclude={"id", "current_turn_id"}) == copy_ctx.model_dump(exclude={"id", "current_turn_id"})
assert context.model_dump(exclude={"id", "current_turn_id"}) == copy_ctx.model_dump(
exclude={"id", "current_turn_id"}
)


async def test_benchmark_config(context_storage: JSONContextStorage, monkeypatch: pytest.MonkeyPatch):
Expand Down

0 comments on commit 47edbda

Please sign in to comment.