diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/llama_index/storage/chat_store/azure/base.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/llama_index/storage/chat_store/azure/base.py index 34d89a370efbd..e676f1e70c3ac 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/llama_index/storage/chat_store/azure/base.py +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/llama_index/storage/chat_store/azure/base.py @@ -1,4 +1,3 @@ -import asyncio from itertools import chain from typing import Any, List, Optional @@ -12,6 +11,7 @@ ) from azure.data.tables.aio import TableServiceClient as AsyncTableServiceClient +from llama_index.core.async_utils import asyncio_run from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.llms import ChatMessage from llama_index.core.storage.chat_store.base import BaseChatStore @@ -161,7 +161,7 @@ def from_aad_token( def set_messages(self, key: str, messages: List[ChatMessage]) -> None: """Set messages for a key.""" - asyncio.run(self.aset_messages(key, messages)) + asyncio_run(self.aset_messages(key, messages)) async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: """Asynchronoulsy set messages for a key.""" @@ -208,7 +208,7 @@ async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: def get_messages(self, key: str) -> List[ChatMessage]: """Get messages for a key.""" - asyncio.run(self.aget_messages(key)) + return asyncio_run(self.aget_messages(key)) async def aget_messages(self, key: str) -> List[ChatMessage]: """Asynchronously get messages for a key.""" @@ -216,14 +216,18 @@ async def aget_messages(self, key: str) -> List[ChatMessage]: self.chat_table_name ) entities = chat_client.query_entities(f"PartitionKey eq '{key}'") - return [ - ChatMessage.parse_obj(deserialize(self.service_mode, entity)) - for entity in entities - ] + messages = [] + + async for entity in entities: + messages.append( + ChatMessage.model_validate(deserialize(self.service_mode, entity)) + ) + + return messages def add_message(self, key: str, message: ChatMessage, idx: int = None): """Add a message for a key.""" - asyncio.run(self.async_add_message(key, message, idx)) + asyncio_run(self.async_add_message(key, message, idx)) async def async_add_message(self, key: str, message: ChatMessage, idx: int = None): metadata_client = await self._atable_service_client.create_table_if_not_exists( @@ -255,11 +259,11 @@ async def async_add_message(self, key: str, message: ChatMessage, idx: int = Non metadata["LastMessageRowKey"] = self._to_row_key(idx) metadata["MessageCount"] = next_index + 1 # Update medatada - metadata_client.upsert_entity(metadata, UpdateMode.MERGE) + await metadata_client.upsert_entity(metadata, UpdateMode.MERGE) def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: # Delete all messages for the key - asyncio.run(self.adelete_messages(key)) + return asyncio_run(self.adelete_messages(key)) async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: """Asynchronously delete all messages for a key.""" @@ -280,7 +284,7 @@ async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: """Delete specific message for a key.""" - asyncio.run(self.adelete_message(key, idx)) + return asyncio_run(self.adelete_message(key, idx)) async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: """Asynchronously delete specific message for a key.""" @@ -313,7 +317,7 @@ async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: def delete_last_message(self, key: str) -> Optional[ChatMessage]: """Delete last message for a key.""" - asyncio.run(self.adelete_last_message(key)) + return asyncio_run(self.adelete_last_message(key)) async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: """Async delete last message for a key.""" @@ -342,7 +346,7 @@ async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: def get_keys(self) -> List[str]: """Get all keys.""" - asyncio.run(self.aget_keys()) + return asyncio_run(self.aget_keys()) async def aget_keys(self) -> List[str]: """Asynchronously get all keys.""" @@ -352,7 +356,12 @@ async def aget_keys(self) -> List[str]: entities = metadata_client.query_entities( f"PartitionKey eq '{self.metadata_partition_key}'" ) - return [entity["RowKey"] for entity in entities] + + keys = [] + async for entity in entities: + keys.append(entity["RowKey"]) + + return keys @classmethod def class_name(cls) -> str: diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/pyproject.toml b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/pyproject.toml index 89a383f059fce..c712f745f27d8 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/pyproject.toml +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-storage-chat-store-azure" readme = "README.md" -version = "0.2.1" +version = "0.2.2" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" @@ -42,6 +42,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" +pytest-asyncio = "*" pytest-mock = "3.11.1" ruff = "0.0.292" tree-sitter-languages = "^1.8.0" diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/tests/test_chat_store_azure_chat_store.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/tests/test_chat_store_azure_chat_store.py index 74526c7dd83b3..6853cd08f2417 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/tests/test_chat_store_azure_chat_store.py +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-azure/tests/test_chat_store_azure_chat_store.py @@ -1,7 +1,87 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock from llama_index.core.storage.chat_store.base import BaseChatStore from llama_index.storage.chat_store.azure import AzureChatStore +from llama_index.core.llms import ChatMessage +from azure.data.tables import TableServiceClient, TableClient, TableEntity +from azure.data.tables.aio import TableServiceClient as AsyncTableServiceClient def test_class(): names_of_base_classes = [b.__name__ for b in AzureChatStore.__mro__] assert BaseChatStore.__name__ in names_of_base_classes + + +@pytest.fixture() +def azure_chat_store(): + mock_table_service_client = MagicMock(spec=TableServiceClient) + mock_atable_service_client = AsyncMock(spec=AsyncTableServiceClient) + return AzureChatStore(mock_table_service_client, mock_atable_service_client) + + +def test_set_messages(azure_chat_store): + key = "test_key" + messages = [ + ChatMessage(role="user", content="Hello"), + ChatMessage(role="assistant", content="Hi there!"), + ] + + mock_chat_client = AsyncMock(spec=TableClient) + mock_metadata_client = AsyncMock(spec=TableClient) + + azure_chat_store._atable_service_client.create_table_if_not_exists.side_effect = [ + mock_chat_client, + mock_metadata_client, + ] + + # Mock the query_entities method to return an empty list + mock_chat_client.query_entities.return_value = AsyncMock() + mock_chat_client.query_entities.return_value.__aiter__.return_value = [] + mock_chat_client.submit_transaction = AsyncMock() + mock_metadata_client.upsert_entity = AsyncMock() + + azure_chat_store.set_messages(key, messages) + + azure_chat_store._atable_service_client.create_table_if_not_exists.assert_any_call( + azure_chat_store.chat_table_name + ) + azure_chat_store._atable_service_client.create_table_if_not_exists.assert_any_call( + azure_chat_store.metadata_table_name + ) + mock_chat_client.submit_transaction.assert_called_once() + mock_metadata_client.upsert_entity.assert_called_once() + + +def test_get_messages(azure_chat_store): + key = "test_key" + + mock_chat_client = AsyncMock(spec=TableClient) + azure_chat_store._atable_service_client.create_table_if_not_exists.return_value = ( + mock_chat_client + ) + + # Create mock TableEntity objects + mock_entities = [ + TableEntity( + PartitionKey=key, RowKey="0000000000", role="user", content="Hello" + ), + TableEntity( + PartitionKey=key, RowKey="0000000001", role="assistant", content="Hi there!" + ), + ] + + # Set up the mock to return an async iterator of the mock entities + mock_chat_client.query_entities.return_value = AsyncMock() + mock_chat_client.query_entities.return_value.__aiter__.return_value = mock_entities + + result = azure_chat_store.get_messages(key) + + azure_chat_store._atable_service_client.create_table_if_not_exists.assert_called_once_with( + azure_chat_store.chat_table_name + ) + mock_chat_client.query_entities.assert_called_once_with(f"PartitionKey eq '{key}'") + assert len(result) == 2 + assert result[0].role == "user" + assert result[0].content == "Hello" + assert result[1].role == "assistant" + assert result[1].content == "Hi there!"