Astra DB, chat message history (#13836)
This PR adds a chat message history component that uses Astra DB for persistence through the JSON API. The `astrapy` package is required for this class to work. I have added tests and a small notebook, and updated the relevant references in the other docs pages. (@rlancemartin this is the counterpart of the Cassandra equivalent class you so helpfully reviewed back at the end of June) Thank you!
1 parent 58f7e10, commit 272df9d
Showing 7 changed files with 386 additions and 0 deletions.
147 changes: 147 additions & 0 deletions
docs/docs/integrations/memory/astradb_chat_message_history.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "90cd3ded", | ||
"metadata": {}, | ||
"source": [ | ||
"# Astra DB \n", | ||
"\n", | ||
"> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Cassandra and made conveniently available through an easy-to-use JSON API.\n", | ||
"\n", | ||
"This notebook goes over how to use Astra DB to store chat message history." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f507f58b-bf22-4a48-8daf-68d869bcd1ba", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setting up\n", | ||
"\n", | ||
"To run this notebook you need a running Astra DB. Get the connection secrets on your Astra dashboard:\n", | ||
"\n", | ||
"- the API Endpoint looks like `https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com`;\n", | ||
"- the Token looks like `AstraCS:6gBhNmsk135...`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d7092199", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install --quiet \"astrapy>=0.6.2\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e3d97b65", | ||
"metadata": {}, | ||
"source": [ | ||
"### Set up the database connection parameters and secrets" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "163d97f0", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdin", | ||
"output_type": "stream", | ||
"text": [ | ||
"ASTRA_DB_API_ENDPOINT = https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com\n", | ||
"ASTRA_DB_APPLICATION_TOKEN = ········\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import getpass\n", | ||
"\n", | ||
"ASTRA_DB_API_ENDPOINT = input(\"ASTRA_DB_API_ENDPOINT = \")\n", | ||
"ASTRA_DB_APPLICATION_TOKEN = getpass.getpass(\"ASTRA_DB_APPLICATION_TOKEN = \")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "55860b2d", | ||
"metadata": {}, | ||
"source": [ | ||
"The API endpoint and token collected above are passed directly to the `AstraDBChatMessageHistory` constructor; no separate connection or \"Session\" object is needed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "36c163e8", | ||
"metadata": {}, | ||
"source": [ | ||
"## Example" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "d15e3302", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.memory import AstraDBChatMessageHistory\n", | ||
"\n", | ||
"message_history = AstraDBChatMessageHistory(\n", | ||
" session_id=\"test-session\",\n", | ||
" api_endpoint=ASTRA_DB_API_ENDPOINT,\n", | ||
" token=ASTRA_DB_APPLICATION_TOKEN,\n", | ||
")\n", | ||
"\n", | ||
"message_history.add_user_message(\"hi!\")\n", | ||
"\n", | ||
"message_history.add_ai_message(\"whats up?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "64fc465e", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[HumanMessage(content='hi!'), AIMessage(content='whats up?')]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"message_history.messages" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
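The notebook above reads the endpoint and token interactively. For scripted or CI runs, a minimal sketch that prefers environment variables and falls back to the same prompts could look like the following. The helper name `read_setting` and the use of environment variables named after the notebook variables are assumptions for illustration, not part of this commit.

```python
import getpass
import os
from typing import Mapping, Optional


def read_setting(
    name: str,
    env: Optional[Mapping[str, str]] = None,
    secret: bool = False,
) -> str:
    """Return `name` from the environment, prompting only when it is unset.

    Hypothetical helper: not part of langchain or astrapy.
    """
    env = os.environ if env is None else env
    value = env.get(name)
    if value:
        return value
    prompt = f"{name} = "
    # Hide the typed value for secrets, as the notebook does for the token.
    return getpass.getpass(prompt) if secret else input(prompt)


# Demonstration with an explicit mapping, so no prompt is triggered.
# Both values are placeholders, not real credentials.
demo_env = {
    "ASTRA_DB_API_ENDPOINT": "https://example.invalid",
    "ASTRA_DB_APPLICATION_TOKEN": "AstraCS:placeholder",
}
endpoint = read_setting("ASTRA_DB_API_ENDPOINT", demo_env)
token = read_setting("ASTRA_DB_APPLICATION_TOKEN", demo_env, secret=True)
```

In real use the `env` argument would be omitted so that `os.environ` is consulted, and the two results would be passed as `api_endpoint` and `token` exactly as in the notebook's example cell.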
114 changes: 114 additions & 0 deletions
libs/langchain/langchain/memory/chat_message_histories/astradb.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""Astra DB - based chat message history, based on astrapy.""" | ||
from __future__ import annotations | ||
|
||
import json | ||
import time | ||
import typing | ||
from typing import List, Optional | ||
|
||
if typing.TYPE_CHECKING: | ||
from astrapy.db import AstraDB as LibAstraDB | ||
|
||
from langchain_core.chat_history import BaseChatMessageHistory | ||
from langchain_core.messages import ( | ||
BaseMessage, | ||
message_to_dict, | ||
messages_from_dict, | ||
) | ||
|
||
DEFAULT_COLLECTION_NAME = "langchain_message_store" | ||
|
||
|
||
class AstraDBChatMessageHistory(BaseChatMessageHistory): | ||
"""Chat message history that stores history in Astra DB. | ||
Args (only keyword-arguments accepted): | ||
session_id: arbitrary key that is used to store the messages | ||
of a single chat session. | ||
collection_name (str): name of the Astra DB collection to create/use. | ||
token (Optional[str]): API token for Astra DB usage. | ||
api_endpoint (Optional[str]): full URL to the API endpoint, | ||
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". | ||
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, | ||
you can pass an already-created 'astrapy.db.AstraDB' instance. | ||
namespace (Optional[str]): namespace (aka keyspace) where the | ||
collection is created. Defaults to the database's "default namespace". | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
session_id: str, | ||
collection_name: str = DEFAULT_COLLECTION_NAME, | ||
token: Optional[str] = None, | ||
api_endpoint: Optional[str] = None, | ||
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB' | ||
namespace: Optional[str] = None, | ||
) -> None: | ||
"""Create an Astra DB chat message history.""" | ||
try: | ||
from astrapy.db import AstraDB as LibAstraDB | ||
except (ImportError, ModuleNotFoundError): | ||
raise ImportError( | ||
"Could not import a recent astrapy python package. " | ||
"Please install it with `pip install --upgrade astrapy`." | ||
) | ||
|
||
# Conflicting-arg checks: | ||
if astra_db_client is not None: | ||
if token is not None or api_endpoint is not None: | ||
raise ValueError( | ||
"You cannot pass 'astra_db_client' to AstraDBChatMessageHistory " | ||
"together with 'token' or 'api_endpoint'." | ||
) | ||
|
||
self.session_id = session_id | ||
self.collection_name = collection_name | ||
self.token = token | ||
self.api_endpoint = api_endpoint | ||
self.namespace = namespace | ||
if astra_db_client is not None: | ||
self.astra_db = astra_db_client | ||
else: | ||
self.astra_db = LibAstraDB( | ||
token=self.token, | ||
api_endpoint=self.api_endpoint, | ||
namespace=self.namespace, | ||
) | ||
self.collection = self.astra_db.create_collection(self.collection_name) | ||
|
||
@property | ||
def messages(self) -> List[BaseMessage]: # type: ignore | ||
"""Retrieve all session messages from DB""" | ||
message_blobs = [ | ||
doc["body_blob"] | ||
for doc in sorted( | ||
self.collection.paginated_find( | ||
filter={ | ||
"session_id": self.session_id, | ||
}, | ||
projection={ | ||
"timestamp": 1, | ||
"body_blob": 1, | ||
}, | ||
), | ||
key=lambda _doc: _doc["timestamp"], | ||
) | ||
] | ||
items = [json.loads(message_blob) for message_blob in message_blobs] | ||
messages = messages_from_dict(items) | ||
return messages | ||
|
||
def add_message(self, message: BaseMessage) -> None: | ||
"""Write a message to the collection""" | ||
self.collection.insert_one( | ||
{ | ||
"timestamp": time.time(), | ||
"session_id": self.session_id, | ||
"body_blob": json.dumps(message_to_dict(message)), | ||
} | ||
) | ||
|
||
def clear(self) -> None: | ||
"""Clear session memory from DB""" | ||
self.collection.delete_many(filter={"session_id": self.session_id}) |
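The storage scheme used by the class above (one document per message, keyed by `session_id`, ordered by `timestamp` on read, with the message serialized into `body_blob`) can be illustrated without a database. The sketch below is a simplified stand-in under stated assumptions: `FakeCollection` and `SketchHistory` are hypothetical names, the fake only imitates the three collection methods the class calls, messages are plain dicts rather than LangChain `BaseMessage` objects, and the projection argument is omitted.

```python
import json
import time
from typing import Any, Callable, Dict, Iterator, List


class FakeCollection:
    """Hypothetical in-memory stand-in for an Astra DB collection (not astrapy)."""

    def __init__(self) -> None:
        self.docs: List[Dict[str, Any]] = []

    def insert_one(self, document: Dict[str, Any]) -> None:
        self.docs.append(dict(document))

    # The parameter is named `filter` to match the keyword used in the class above.
    def paginated_find(self, filter: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        # Yield every document whose fields equal all filter entries.
        for doc in self.docs:
            if all(doc.get(k) == v for k, v in filter.items()):
                yield doc

    def delete_many(self, filter: Dict[str, Any]) -> None:
        # Keep only the documents that do NOT match the filter.
        self.docs = [
            doc
            for doc in self.docs
            if not all(doc.get(k) == v for k, v in filter.items())
        ]


class SketchHistory:
    """Mirrors the read/write/clear logic of the class above on plain dicts."""

    def __init__(
        self,
        session_id: str,
        collection: FakeCollection,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self.session_id = session_id
        self.collection = collection
        self.clock = clock  # injectable so the demo below is deterministic

    @property
    def messages(self) -> List[Dict[str, Any]]:
        docs = sorted(
            self.collection.paginated_find(filter={"session_id": self.session_id}),
            key=lambda doc: doc["timestamp"],
        )
        return [json.loads(doc["body_blob"]) for doc in docs]

    def add_message(self, message: Dict[str, Any]) -> None:
        self.collection.insert_one(
            {
                "timestamp": self.clock(),
                "session_id": self.session_id,
                "body_blob": json.dumps(message),
            }
        )

    def clear(self) -> None:
        self.collection.delete_many(filter={"session_id": self.session_id})


# Two sessions share one collection, as they would with the default collection name.
ticks = iter(range(100))


def fake_clock() -> float:
    return float(next(ticks))


shared = FakeCollection()
alice = SketchHistory("alice", shared, fake_clock)
bob = SketchHistory("bob", shared, fake_clock)

alice.add_message({"type": "human", "content": "hi!"})
bob.add_message({"type": "human", "content": "hello"})
alice.add_message({"type": "ai", "content": "whats up?"})

alice_contents = [m["content"] for m in alice.messages]
bob.clear()  # removes only bob's documents
```

Because ordering relies on the `timestamp` field alone, two messages written with an identical timestamp have no guaranteed relative order in a real collection; the fake preserves insertion order only because Python's `sorted` is stable.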