Skip to content

Commit

Permalink
turns added, empty ctx_dict method also added
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 22, 2024
1 parent 1f96f6d commit ce6c8b6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
23 changes: 11 additions & 12 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from asyncio import Event, gather
from uuid import uuid4
from time import time_ns
from typing import Any, Callable, Optional, Dict, TYPE_CHECKING
from typing import Any, Callable, Iterable, Optional, Dict, TYPE_CHECKING, Tuple
import logging

from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, model_validator
Expand All @@ -39,14 +39,6 @@
logger = logging.getLogger(__name__)


"""
class Turn(BaseModel):
label: Optional[NodeLabel2Type] = Field(default=None)
request: Optional[Message] = Field(default=None)
response: Optional[Message] = Field(default=None)
"""


class ContextError(Exception):
"""Raised when context methods are not used correctly."""

Expand Down Expand Up @@ -110,9 +102,9 @@ class Context(BaseModel):
It is set (and managed) by :py:class:`~chatsky.context_storages.DBContextStorage`.
"""
current_turn_id: int = Field(default=0)
labels: ContextDict[int, AbsoluteNodeLabel] = Field(default_factory=ContextDict)
requests: ContextDict[int, Message] = Field(default_factory=ContextDict)
responses: ContextDict[int, Message] = Field(default_factory=ContextDict)
labels: ContextDict[int, AbsoluteNodeLabel] = Field(default_factory=lambda: ContextDict.empty(AbsoluteNodeLabel))
requests: ContextDict[int, Message] = Field(default_factory=lambda: ContextDict.empty(Message))
responses: ContextDict[int, Message] = Field(default_factory=lambda: ContextDict.empty(Message))
"""
`turns` stores the history of all passed `labels`, `requests`, and `responses`.
Expand Down Expand Up @@ -227,6 +219,13 @@ def current_node(self) -> Node:
raise ContextError("Current node is not set.")
return node

async def turns(self, key: slice) -> Iterable[Tuple[AbsoluteNodeLabel, Message, Message]]:
return zip(*gather(
self.labels.__getitem__(key),
self.requests.__getitem__(key),
self.responses.__getitem__(key)
))

def __eq__(self, value: object) -> bool:
if isinstance(value, Context):
return (
Expand Down
11 changes: 8 additions & 3 deletions chatsky/core/ctx_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@ class ContextDict(BaseModel, Generic[K, V]):
_value_type: Optional[TypeAdapter[Type[V]]] = PrivateAttr(None)

@classmethod
async def new(cls, storage: DBContextStorage, id: str, field: str, value_type: Type[V]) -> "ContextDict":
def empty(cls, value_type: Type[V]) -> "ContextDict":
instance = cls()
instance._value_type = TypeAdapter(value_type)
return instance

@classmethod
async def new(cls, storage: DBContextStorage, id: str, field: str, value_type: Type[V]) -> "ContextDict":
instance = cls.empty(value_type)
logger.debug(f"Disconnected context dict created for id {id} and field name: {field}")
instance._storage = storage
instance._ctx_id = id
instance._field_name = field
instance._value_type = TypeAdapter(value_type)
instance._storage = storage
return instance

@classmethod
Expand Down
6 changes: 1 addition & 5 deletions tests/core/test_context_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pytest

from pydantic import TypeAdapter

from chatsky.context_storages import MemoryContextStorage
from chatsky.core.message import Message
from chatsky.core.ctx_dict import ContextDict
Expand All @@ -11,9 +9,7 @@ class TestContextDict:
@pytest.fixture(scope="function")
async def empty_dict(self) -> ContextDict:
# Empty (disconnected) context dictionary
ctx_dict = ContextDict()
ctx_dict._value_type = TypeAdapter(Message)
return ctx_dict
return ContextDict.empty(Message)

@pytest.fixture(scope="function")
async def attached_dict(self) -> ContextDict:
Expand Down

0 comments on commit ce6c8b6

Please sign in to comment.