Skip to content

Commit

Permalink
Astra DB, chat message history (#13836)
Browse files Browse the repository at this point in the history
This PR adds a chat message history component that uses Astra DB for
persistence through the JSON API.
The `astrapy` package is required for this class to work.

I have added tests and a small notebook, and updated the relevant
references in the other docs pages.

(@rlancemartin this is the counterpart of the Cassandra equivalent class
you so helpfully reviewed back at the end of June)

Thank you!
  • Loading branch information
hemidactylus authored Nov 25, 2023
1 parent 58f7e10 commit 272df9d
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 0 deletions.
147 changes: 147 additions & 0 deletions docs/docs/integrations/memory/astradb_chat_message_history.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "90cd3ded",
"metadata": {},
"source": [
"# Astra DB \n",
"\n",
"> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Cassandra and made conveniently available through an easy-to-use JSON API.\n",
"\n",
"This notebook goes over how to use Astra DB to store chat message history."
]
},
{
"cell_type": "markdown",
"id": "f507f58b-bf22-4a48-8daf-68d869bcd1ba",
"metadata": {},
"source": [
"## Setting up\n",
"\n",
"To run this notebook you need a running Astra DB. Get the connection secrets on your Astra dashboard:\n",
"\n",
"- the API Endpoint looks like `https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com`;\n",
"- the Token looks like `AstraCS:6gBhNmsk135...`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7092199",
"metadata": {},
"outputs": [],
"source": [
"!pip install --quiet \"astrapy>=0.6.2\""
]
},
{
"cell_type": "markdown",
"id": "e3d97b65",
"metadata": {},
"source": [
"### Set up the database connection parameters and secrets"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "163d97f0",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"ASTRA_DB_API_ENDPOINT = https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com\n",
"ASTRA_DB_APPLICATION_TOKEN = ········\n"
]
}
],
"source": [
"import getpass\n",
"\n",
"ASTRA_DB_API_ENDPOINT = input(\"ASTRA_DB_API_ENDPOINT = \")\n",
"ASTRA_DB_APPLICATION_TOKEN = getpass.getpass(\"ASTRA_DB_APPLICATION_TOKEN = \")"
]
},
{
"cell_type": "markdown",
"id": "55860b2d",
"metadata": {},
"source": [
"Depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object."
]
},
{
"cell_type": "markdown",
"id": "36c163e8",
"metadata": {},
"source": [
"## Example"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d15e3302",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import AstraDBChatMessageHistory\n",
"\n",
"message_history = AstraDBChatMessageHistory(\n",
" session_id=\"test-session\",\n",
" api_endpoint=ASTRA_DB_API_ENDPOINT,\n",
" token=ASTRA_DB_APPLICATION_TOKEN,\n",
")\n",
"\n",
"message_history.add_user_message(\"hi!\")\n",
"\n",
"message_history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "64fc465e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='hi!'), AIMessage(content='whats up?')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"message_history.messages"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
14 changes: 14 additions & 0 deletions docs/docs/integrations/providers/astradb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ vector_store = AstraDB(
Learn more in the [example notebook](/docs/integrations/vectorstores/astradb).


### Memory

```python
from langchain.memory import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session"
api_endpoint="...",
token="...",
)
```

Learn more in the [example notebook](/docs/integrations/memory/astradb_chat_message_history).


## Apache Cassandra and Astra DB through CQL

> [Cassandra](https://cassandra.apache.org/) is a NoSQL, row-oriented, highly scalable and highly available database.
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain/langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import (
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
Expand Down Expand Up @@ -68,6 +69,7 @@
from langchain.memory.zep_memory import ZepMemory

__all__ = [
"AstraDBChatMessageHistory",
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CombinedMemory",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from langchain.memory.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
Expand Down Expand Up @@ -31,6 +34,7 @@
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory

__all__ = [
"AstraDBChatMessageHistory",
"ChatMessageHistory",
"CassandraChatMessageHistory",
"CosmosDBChatMessageHistory",
Expand Down
114 changes: 114 additions & 0 deletions libs/langchain/langchain/memory/chat_message_histories/astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations

import json
import time
import typing
from typing import List, Optional

if typing.TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

DEFAULT_COLLECTION_NAME = "langchain_message_store"


class AstraDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Astra DB.
Args (only keyword-arguments accepted):
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""

def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
namespace: Optional[str] = None,
) -> None:
"""Create an Astra DB chat message history."""
try:
from astrapy.db import AstraDB as LibAstraDB
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)

# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)

self.session_id = session_id
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection(self.collection_name)

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all session messages from DB"""
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

def add_message(self, message: BaseMessage) -> None:
"""Write a message to the table"""
self.collection.insert_one(
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
)

def clear(self) -> None:
"""Clear session memory from DB"""
self.collection.delete_many(filter={"session_id": self.session_id})
Loading

0 comments on commit 272df9d

Please sign in to comment.