Skip to content

Commit

Permalink
astradb: Add AstraDBChatMessageHistory to langchain-astradb package (#…
Browse files Browse the repository at this point in the history
…17732)

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
cbornet and baskaryan authored Feb 26, 2024
1 parent c06a873 commit b8b5ce0
Show file tree
Hide file tree
Showing 9 changed files with 1,083 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/docs/integrations/providers/astradb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Learn more in the [example notebook](/docs/integrations/vectorstores/astradb).
## Chat message history

```python
from langchain_community.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session",
api_endpoint="...",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB

from langchain_core._api.deprecation import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
Expand All @@ -23,6 +24,11 @@
DEFAULT_COLLECTION_NAME = "langchain_message_store"


@deprecated(
since="0.0.25",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBChatMessageHistory",
)
class AstraDBChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
Expand Down
13 changes: 12 additions & 1 deletion libs/partners/astradb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pip install langchain-astradb
### Vector Store

```python
from langchain_astradb.vectorstores import AstraDBVectorStore
from langchain_astradb import AstraDBVectorStore

my_store = AstraDBVectorStore(
embedding=my_embeddings,
Expand All @@ -30,6 +30,17 @@ my_store = AstraDBVectorStore(
)
```

### Chat message history

```python
from langchain_astradb import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session",
api_endpoint="...",
token="...",
)
```

### Store

```python
Expand Down
2 changes: 2 additions & 0 deletions libs/partners/astradb/langchain_astradb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore

__all__ = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBChatMessageHistory",
"AstraDBVectorStore",
]
148 changes: 148 additions & 0 deletions libs/partners/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations

import json
import time
from typing import List, Optional, Sequence

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)

DEFAULT_COLLECTION_NAME = "langchain_message_store"


class AstraDBChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""Chat message history that stores history in Astra DB.
Args:
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)

self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

self.session_id = session_id
self.collection_name = collection_name

@property
def messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
self.astra_env.ensure_db_setup()
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
raise NotImplementedError("Use add_messages instead")

async def aget_messages(self) -> List[BaseMessage]:
await self.astra_env.aensure_db_setup()
docs = self.async_collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
)
sorted_docs = sorted(
[doc async for doc in docs],
key=lambda _doc: _doc["timestamp"],
)
message_blobs = [doc["body_blob"] for doc in sorted_docs]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self.astra_env.ensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
self.collection.chunked_insert_many(docs)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
await self.astra_env.aensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
await self.async_collection.chunked_insert_many(docs)

def clear(self) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"session_id": self.session_id})

async def aclear(self) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"session_id": self.session_id})
Loading

0 comments on commit b8b5ce0

Please sign in to comment.