Skip to content

Commit

Permalink
turns added and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 25, 2024
1 parent 1d3859c commit 5514c7b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
14 changes: 7 additions & 7 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from asyncio import Event, gather
from uuid import uuid4
from time import time_ns
from typing import Any, Callable, Iterable, Optional, Dict, TYPE_CHECKING, Tuple
from typing import Any, Callable, Iterable, Optional, Dict, TYPE_CHECKING, Tuple, Union
import logging

from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, model_validator
Expand Down Expand Up @@ -219,12 +219,12 @@ def current_node(self) -> Node:
raise ContextError("Current node is not set.")
return node

async def turns(self, key: slice) -> Iterable[Tuple[AbsoluteNodeLabel, Message, Message]]:
return zip(*gather(
self.labels.__getitem__(key),
self.requests.__getitem__(key),
self.responses.__getitem__(key)
))
async def turns(self, key: Union[int, slice]) -> Iterable[Tuple[AbsoluteNodeLabel, Message, Message]]:
turn_ids = range(self.current_turn_id + 1)[key]
turn_ids = turn_ids if isinstance(key, slice) else [turn_ids]
context_dicts = (self.labels, self.requests, self.responses)
turns_lists = await gather(*[gather(*[ctd.__getitem__(ti) for ti in turn_ids]) for ctd in context_dicts])
return zip(*turns_lists)

def __eq__(self, value: object) -> bool:
if isinstance(value, Context):
Expand Down
4 changes: 2 additions & 2 deletions chatsky/core/ctx_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _value_type(self) -> TypeAdapter[Type[V]]:
raise NotImplementedError

@classmethod
async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]":
instance = cls()
logger.debug(f"Disconnected context dict created for id {id} and field name: {field}")
instance._ctx_id = id
Expand All @@ -66,7 +66,7 @@ async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDi
return instance

@classmethod
async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]":
logger.debug(f"Connected context dict created for {id}, {field}")
keys, items = await gather(storage.load_field_keys(id, field), storage.load_field_latest(id, field))
val_key_items = [(k, v) for k, v in items if v is not None]
Expand Down
51 changes: 45 additions & 6 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from altair import Key
import pytest

from chatsky.core.context import Context, ContextError
Expand All @@ -13,11 +14,11 @@ class TestLabels:
def ctx(self, context_factory):
return context_factory(forbidden_fields=["requests", "responses"])

def test_raises_on_empty_labels(self, ctx):
def test_raises_on_empty_labels(self, ctx: Context):
with pytest.raises(ContextError):
ctx.last_label

def test_existing_labels(self, ctx):
def test_existing_labels(self, ctx: Context):
ctx.labels[5] = ("flow", "node1")

assert ctx.last_label == AbsoluteNodeLabel(flow_name="flow", node_name="node1")
Expand All @@ -31,14 +32,14 @@ class TestRequests:
def ctx(self, context_factory):
return context_factory(forbidden_fields=["labels", "responses"])

def test_existing_requests(self, ctx):
def test_existing_requests(self, ctx: Context):
ctx.requests[5] = Message(text="text1")
assert ctx.last_request == Message(text="text1")
ctx.requests[6] = "text2"
assert ctx.requests.keys() == [5, 6]
assert ctx.last_request == Message(text="text2")

def test_empty_requests(self, ctx):
def test_empty_requests(self, ctx: Context):
with pytest.raises(ContextError):
ctx.last_request

Expand All @@ -52,14 +53,14 @@ class TestResponses:
def ctx(self, context_factory):
return context_factory(forbidden_fields=["labels", "requests"])

def test_existing_responses(self, ctx):
def test_existing_responses(self, ctx: Context):
ctx.responses[5] = Message(text="text1")
assert ctx.last_response == Message(text="text1")
ctx.responses[6] = "text2"
assert ctx.responses.keys() == [5, 6]
assert ctx.last_response == Message(text="text2")

def test_empty_responses(self, ctx):
def test_empty_responses(self, ctx: Context):
with pytest.raises(ContextError):
ctx.last_response

Expand All @@ -68,6 +69,44 @@ def test_empty_responses(self, ctx):
assert ctx.responses.keys() == [1]


class TestTurns:
@pytest.fixture
def ctx(self, context_factory):
return context_factory()

async def test_complete_turn(self, ctx: Context):
ctx.labels[5] = ("flow", "node5")
ctx.requests[5] = Message(text="text5")
ctx.responses[5] = Message(text="text5")
ctx.current_turn_id = 5

label, request, response = list(await ctx.turns(5))[0]
assert label == AbsoluteNodeLabel(flow_name="flow", node_name="node5")
assert request == Message(text="text5")
assert response == Message(text="text5")

async def test_partial_turn(self, ctx: Context):
ctx.labels[6] = ("flow", "node6")
ctx.requests[6] = Message(text="text6")
ctx.current_turn_id = 6

with pytest.raises(KeyError):
await ctx.turns(6)

async def test_slice_turn(self, ctx: Context):
for i in range(2, 6):
ctx.labels[i] = ("flow", f"node{i}")
ctx.requests[i] = Message(text=f"text{i}")
ctx.responses[i] = Message(text=f"text{i}")
ctx.current_turn_id = i

labels, requests, responses = zip(*(await ctx.turns(slice(2, 6))))
for i in range(2, 6):
assert AbsoluteNodeLabel(flow_name="flow", node_name=f"node{i}") in labels
assert Message(text=f"text{i}") in requests
assert Message(text=f"text{i}") in responses


async def test_pipeline_available():
class MyResponse(BaseResponse):
async def call(self, ctx: Context) -> MessageInitTypes:
Expand Down

0 comments on commit 5514c7b

Please sign in to comment.