Skip to content

Commit

Permalink
test-time comparison fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 29, 2024
1 parent e40786c commit a54df18
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
This class implements the basic functionality and can be extended to add additional features as needed.
"""

from __future__ import annotations
from abc import ABC, abstractmethod
from asyncio import Lock
from json import loads
from functools import wraps
from importlib import import_module
from logging import getLogger
Expand Down Expand Up @@ -52,29 +54,16 @@ class ContextInfo(BaseModel):
created_at: int = Field(default_factory=time_ns)
updated_at: int = Field(default_factory=time_ns)
misc: Dict[str, Any] = Field(default_factory=dict)
framework_data: FrameworkData = Field(default_factory=FrameworkData)
framework_data: FrameworkData = Field(default_factory=dict, validate_default=True)

_misc_adaptor: TypeAdapter[Dict[str, Any]] = PrivateAttr(default=TypeAdapter(Dict[str, Any]))

@field_validator("misc")
@field_validator("framework_data", "misc", mode="before")
@classmethod
def _validate_misc(cls, value: Any) -> Dict[str, Any]:
if isinstance(value, Dict):
return value
elif isinstance(value, bytes) or isinstance(value, str):
return cls._misc_adaptor.validate_json(value)
else:
raise ValidationError(f"Value of type {type(value).__name__} can not be validated as misc!")

@field_validator("framework_data")
@classmethod
def _validate_framework_data(cls, value: Any) -> FrameworkData:
if isinstance(value, FrameworkData):
return value
elif isinstance(value, bytes) or isinstance(value, str):
return FrameworkData.model_validate_json(value)
else:
raise ValidationError(f"Value of type {type(value).__name__} can not be validated as framework data!")
def _validate_framework_data(cls, value: Any) -> Dict:
if isinstance(value, bytes) or isinstance(value, str):
value = loads(value)
return value

@field_serializer("misc", when_used="always")
def _serialize_misc(self, misc: Dict[str, Any]) -> bytes:
Expand All @@ -84,6 +73,11 @@ def _serialize_misc(self, misc: Dict[str, Any]) -> bytes:
def serialize_courses_in_order(self, framework_data: FrameworkData) -> bytes:
return framework_data.model_dump_json().encode()

def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
return self.model_dump() == other.model_dump()
return super().__eq__(other)


class DBContextStorage(ABC):
_default_subscript_value: int = 3
Expand Down

0 comments on commit a54df18

Please sign in to comment.