Skip to content

Commit

Permalink
Fix async methods in azure chat store (run-llama#16531)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Oct 14, 2024
1 parent 7a9deee commit 96cd14f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from itertools import chain
from typing import Any, List, Optional

Expand All @@ -12,6 +11,7 @@
)
from azure.data.tables.aio import TableServiceClient as AsyncTableServiceClient

from llama_index.core.async_utils import asyncio_run
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.llms import ChatMessage
from llama_index.core.storage.chat_store.base import BaseChatStore
Expand Down Expand Up @@ -161,7 +161,7 @@ def from_aad_token(

def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
"""Set messages for a key."""
asyncio.run(self.aset_messages(key, messages))
asyncio_run(self.aset_messages(key, messages))

async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:
"""Asynchronoulsy set messages for a key."""
Expand Down Expand Up @@ -208,22 +208,26 @@ async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:

def get_messages(self, key: str) -> List[ChatMessage]:
"""Get messages for a key."""
asyncio.run(self.aget_messages(key))
return asyncio_run(self.aget_messages(key))

async def aget_messages(self, key: str) -> List[ChatMessage]:
"""Asynchronously get messages for a key."""
chat_client = await self._atable_service_client.create_table_if_not_exists(
self.chat_table_name
)
entities = chat_client.query_entities(f"PartitionKey eq '{key}'")
return [
ChatMessage.parse_obj(deserialize(self.service_mode, entity))
for entity in entities
]
messages = []

async for entity in entities:
messages.append(
ChatMessage.model_validate(deserialize(self.service_mode, entity))
)

return messages

def add_message(self, key: str, message: ChatMessage, idx: int = None):
"""Add a message for a key."""
asyncio.run(self.async_add_message(key, message, idx))
asyncio_run(self.async_add_message(key, message, idx))

async def async_add_message(self, key: str, message: ChatMessage, idx: int = None):
metadata_client = await self._atable_service_client.create_table_if_not_exists(
Expand Down Expand Up @@ -255,11 +259,11 @@ async def async_add_message(self, key: str, message: ChatMessage, idx: int = Non
metadata["LastMessageRowKey"] = self._to_row_key(idx)
metadata["MessageCount"] = next_index + 1
# Update medatada
metadata_client.upsert_entity(metadata, UpdateMode.MERGE)
await metadata_client.upsert_entity(metadata, UpdateMode.MERGE)

def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
# Delete all messages for the key
asyncio.run(self.adelete_messages(key))
return asyncio_run(self.adelete_messages(key))

async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:
"""Asynchronously delete all messages for a key."""
Expand All @@ -280,7 +284,7 @@ async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:

def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
"""Delete specific message for a key."""
asyncio.run(self.adelete_message(key, idx))
return asyncio_run(self.adelete_message(key, idx))

async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
"""Asynchronously delete specific message for a key."""
Expand Down Expand Up @@ -313,7 +317,7 @@ async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:

def delete_last_message(self, key: str) -> Optional[ChatMessage]:
"""Delete last message for a key."""
asyncio.run(self.adelete_last_message(key))
return asyncio_run(self.adelete_last_message(key))

async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:
"""Async delete last message for a key."""
Expand Down Expand Up @@ -342,7 +346,7 @@ async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:

def get_keys(self) -> List[str]:
"""Get all keys."""
asyncio.run(self.aget_keys())
return asyncio_run(self.aget_keys())

async def aget_keys(self) -> List[str]:
"""Asynchronously get all keys."""
Expand All @@ -352,7 +356,12 @@ async def aget_keys(self) -> List[str]:
entities = metadata_client.query_entities(
f"PartitionKey eq '{self.metadata_partition_key}'"
)
return [entity["RowKey"] for entity in entities]

keys = []
async for entity in entities:
keys.append(entity["RowKey"])

return keys

@classmethod
def class_name(cls) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-storage-chat-store-azure"
readme = "README.md"
version = "0.2.1"
version = "0.2.2"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand All @@ -42,6 +42,7 @@ mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-asyncio = "*"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,87 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from llama_index.core.storage.chat_store.base import BaseChatStore
from llama_index.storage.chat_store.azure import AzureChatStore
from llama_index.core.llms import ChatMessage
from azure.data.tables import TableServiceClient, TableClient, TableEntity
from azure.data.tables.aio import TableServiceClient as AsyncTableServiceClient


def test_class():
names_of_base_classes = [b.__name__ for b in AzureChatStore.__mro__]
assert BaseChatStore.__name__ in names_of_base_classes


@pytest.fixture()
def azure_chat_store():
mock_table_service_client = MagicMock(spec=TableServiceClient)
mock_atable_service_client = AsyncMock(spec=AsyncTableServiceClient)
return AzureChatStore(mock_table_service_client, mock_atable_service_client)


def test_set_messages(azure_chat_store):
key = "test_key"
messages = [
ChatMessage(role="user", content="Hello"),
ChatMessage(role="assistant", content="Hi there!"),
]

mock_chat_client = AsyncMock(spec=TableClient)
mock_metadata_client = AsyncMock(spec=TableClient)

azure_chat_store._atable_service_client.create_table_if_not_exists.side_effect = [
mock_chat_client,
mock_metadata_client,
]

# Mock the query_entities method to return an empty list
mock_chat_client.query_entities.return_value = AsyncMock()
mock_chat_client.query_entities.return_value.__aiter__.return_value = []
mock_chat_client.submit_transaction = AsyncMock()
mock_metadata_client.upsert_entity = AsyncMock()

azure_chat_store.set_messages(key, messages)

azure_chat_store._atable_service_client.create_table_if_not_exists.assert_any_call(
azure_chat_store.chat_table_name
)
azure_chat_store._atable_service_client.create_table_if_not_exists.assert_any_call(
azure_chat_store.metadata_table_name
)
mock_chat_client.submit_transaction.assert_called_once()
mock_metadata_client.upsert_entity.assert_called_once()


def test_get_messages(azure_chat_store):
key = "test_key"

mock_chat_client = AsyncMock(spec=TableClient)
azure_chat_store._atable_service_client.create_table_if_not_exists.return_value = (
mock_chat_client
)

# Create mock TableEntity objects
mock_entities = [
TableEntity(
PartitionKey=key, RowKey="0000000000", role="user", content="Hello"
),
TableEntity(
PartitionKey=key, RowKey="0000000001", role="assistant", content="Hi there!"
),
]

# Set up the mock to return an async iterator of the mock entities
mock_chat_client.query_entities.return_value = AsyncMock()
mock_chat_client.query_entities.return_value.__aiter__.return_value = mock_entities

result = azure_chat_store.get_messages(key)

azure_chat_store._atable_service_client.create_table_if_not_exists.assert_called_once_with(
azure_chat_store.chat_table_name
)
mock_chat_client.query_entities.assert_called_once_with(f"PartitionKey eq '{key}'")
assert len(result) == 2
assert result[0].role == "user"
assert result[0].content == "Hello"
assert result[1].role == "assistant"
assert result[1].content == "Hi there!"

0 comments on commit 96cd14f

Please sign in to comment.