diff --git a/docs/chat_message_history.ipynb b/docs/chat_message_history.ipynb index 8b1a4cf..e4ff283 100644 --- a/docs/chat_message_history.ipynb +++ b/docs/chat_message_history.ipynb @@ -1,79 +1,168 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Google DATABASE\n", - "\n", - "[Google DATABASE](https://cloud.google.com/DATABASE).\n", - "\n", - "Save chat messages into `DATABASE`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Pre-reqs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%pip install PACKAGE_NAME" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from PACKAGE import LOADER" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Basic Usage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bigtable\n", + "\n", + "[Bigtable](https://cloud.google.com/bigtable) is a key-value and wide-column store, ideal for fast access to structured, semi-structured, or unstructured data.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Before you begin" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To run this notebook, you will need a [Google Cloud Project](https://developers.google.com/workspace/guides/create-project), a [Bigtable instance](https://cloud.google.com/bigtable/docs/creating-instance), and [Google credentials](https://developers.google.com/workspace/guides/create-credentials)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install langchain-google-bigtable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize the schema\n", + "\n", + "The schema for BigtableChatMessageHistory requires the instance and table to exist, and have a column family called `langchain`.\n", + "If the table or the column family do not exist, you can use the following function to create them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.cloud import bigtable\n", + "from langchain_google_bigtable import create_chat_history_table\n", + "\n", + "create_chat_history_table(\n", + " instance_id=\"my-instance\",\n", + " table_id=\"my-table\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_google_bigtable import (\n", + " BigtableChatMessageHistory,\n", + ")\n", + "\n", + "message_history = BigtableChatMessageHistory(\n", + " instance_id=\"my-instance\",\n", + " table_id=\"my-table\",\n", + " session_id=\"user-session-id\",\n", + ")\n", + "\n", + "message_history.add_user_message(\"hi!\")\n", + "message_history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "message_history.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning up\n", + "\n", + "When the history of a specific session is obsolete and can be deleted, it can be done the following way.\n", + "Note: Once deleted, the data is no longer stored in Bigtable and is gone forever." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "message_history.clear()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom client\n", + "The client created by default is the default client, using only admin=True option. To use a non-default, a [custom client](https://cloud.google.com/python/docs/reference/bigtable/latest/client#class-googlecloudbigtableclientclientprojectnone-credentialsnone-readonlyfalse-adminfalse-clientinfonone-clientoptionsnone-adminclientoptionsnone-channelnone) can be passed to the constructor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.cloud import bigtable\n", + "\n", + "client = (bigtable.Client(...),)\n", + "\n", + "create_chat_history_table(\n", + " instance_id=\"my-instance\",\n", + " table_id=\"my-table\",\n", + " client=client,\n", + ")\n", + "\n", + "custom_client_message_history = BigtableChatMessageHistory(\n", + " instance_id=\"my-instance\",\n", + " table_id=\"my-table\",\n", + " client=client,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/langchain_google_bigtable/__init__.py b/src/langchain_google_bigtable/__init__.py index f97003a..658fd06 100644 --- a/src/langchain_google_bigtable/__init__.py +++ b/src/langchain_google_bigtable/__init__.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. + +from langchain_google_bigtable.chat_message_history import ( + BigtableChatMessageHistory, + create_chat_history_table, +) from langchain_google_bigtable.document_loader import ( BigtableLoader, BigtableSaver, @@ -19,4 +24,11 @@ MetadataMapping, ) -__all__ = ["BigtableLoader", "BigtableSaver", "MetadataMapping", "Encoding"] +__all__ = [ + "BigtableChatMessageHistory", + "create_chat_history_table", + "BigtableLoader", + "BigtableSaver", + "MetadataMapping", + "Encoding", +] diff --git a/src/langchain_google_bigtable/chat_message_history.py b/src/langchain_google_bigtable/chat_message_history.py new file mode 100644 index 0000000..2084210 --- /dev/null +++ b/src/langchain_google_bigtable/chat_message_history.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bigtable-based chat message history""" +from __future__ import annotations + +import json +import re +import time +import uuid +from typing import List, Optional + +from google.cloud import bigtable +from google.cloud.bigtable.row_filters import RowKeyRegexFilter +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, messages_from_dict + +COLUMN_FAMILY = "langchain" +COLUMN_NAME = "history" + +default_client: Optional[bigtable.Client] = None + + +def create_chat_history_table( + instance_id: str, + table_id: str, + client: Optional[bigtable.Client] = None, +): + table_client = ( + (client or __get_default_client()).instance(instance_id).table(table_id) + ) + if not table_client.exists(): + table_client.create() + + families = table_client.list_column_families() + if COLUMN_FAMILY not in families: + table_client.column_family( + COLUMN_FAMILY, gc_rule=bigtable.column_family.MaxVersionsGCRule(1) + ).create() + + +def __get_default_client() -> bigtable.Client: + global default_client + if default_client is None: + default_client = bigtable.Client(admin=True) + return default_client + + +class BigtableChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Bigtable. + + Args: + instance_id: The Bigtable instance to use for chat message history. + table_id: The Bigtable table to use for chat message history. + session_id: The session ID. + client : Optional. The pre-created client to query bigtable. + """ + + def __init__( + self, + instance_id: str, + table_id: str, + session_id: str, + client: Optional[bigtable.Client] = None, + ) -> None: + instance = (client or __get_default_client()).instance(instance_id) + if not instance.exists(): + raise NameError(f"Instance {instance_id} does not exist") + + self.table_client = instance.table(table_id) + if not self.table_client.exists(): + raise NameError( + f"Table {table_id} does not exist on instance {instance_id}" + ) + if COLUMN_FAMILY not in self.table_client.list_column_families(): + raise NameError( + f"Column family {COLUMN_FAMILY} does not exist on table {table_id}" + ) + + self.session_id = session_id + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve all session messages from DB""" + rows = self.table_client.read_rows( + filter_=RowKeyRegexFilter( + str.encode("^" + re.escape(self.session_id) + "#.*") + ) + ) + items = [ + json.loads(row.cells[COLUMN_FAMILY][COLUMN_NAME.encode()][0].value.decode()) + for row in rows + ] + messages = messages_from_dict( + [{"type": item["type"], "data": item} for item in items] + ) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Write a message to the table""" + + row_key = str.encode( + self.session_id + + "#" + + str(time.time_ns()).rjust(25, "0") + + "#" + + uuid.uuid4().hex + ) + row = self.table_client.direct_row(row_key) + value = str.encode(message.json()) + row.set_cell(COLUMN_FAMILY, COLUMN_NAME, value) + row.commit() + + def clear(self) -> None: + """Clear session memory from DB""" + row_key_prefix = self.session_id + self.table_client.drop_by_prefix(row_key_prefix) diff --git a/src/test/test_chat_message_history.py b/src/test/test_chat_message_history.py new file mode 100644 index 0000000..042b6a3 --- /dev/null +++ b/src/test/test_chat_message_history.py @@ -0,0 +1,219 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import random +import re +import string +import uuid +from multiprocessing import Process +from typing import Iterator + +import pytest +from google.cloud import bigtable +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage + +from langchain_google_bigtable.chat_message_history import ( + BigtableChatMessageHistory, + create_chat_history_table, +) + +TABLE_ID_PREFIX = "test-table-" + + +@pytest.fixture +def client() -> Iterator[bigtable.Client]: + yield bigtable.Client( + project=get_env_var("PROJECT_ID", "ID of the GCP project"), admin=True + ) + + +@pytest.fixture +def instance_id() -> Iterator[str]: + yield get_env_var("INSTANCE_ID", "ID of the Cloud Bigtable instance") + + +@pytest.fixture +def table_id(instance_id: str, client: bigtable.Client) -> Iterator[str]: + table_id = TABLE_ID_PREFIX + "".join( + random.choice(string.ascii_lowercase) for _ in range(10) + ) + # Create table and column family + create_chat_history_table(instance_id=instance_id, table_id=table_id, client=client) + + yield table_id + + # Teardown + client.instance(instance_id).table(table_id).delete() + + +def test_bigtable_full_workflow( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + session_id = uuid.uuid4().hex + history = BigtableChatMessageHistory( + instance_id, table_id, session_id, client=client + ) + + history.add_ai_message("Hey! I am AI!") + history.add_user_message("Hey! I am human!") + messages = history.messages + + assert len(messages) == 2 + assert isinstance(messages[0], AIMessage) + assert messages[0].content == "Hey! I am AI!" + assert isinstance(messages[1], HumanMessage) + assert messages[1].content == "Hey! I am human!" + + history.clear() + assert len(history.messages) == 0 + + +def get_index_from_message(message: BaseMessage) -> int: + match = re.search("^Hey! I am (AI|human)! Index: ([0-9]+)$", str(message.content)) + if match: + return int(match[2]) + return 0 + + +def test_bigtable_loads_of_messages( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + NUM_MESSAGES = 10000 + session_id = uuid.uuid4().hex + history = BigtableChatMessageHistory( + instance_id, table_id, session_id, client=client + ) + + proc = [] + for i in range(NUM_MESSAGES): + p = Process( + target=lambda i: history.add_ai_message(f"Hey! I am AI! Index: {2*i}"), + args=[i], + ) + p.start() + proc.append(p) + p = Process( + target=lambda i: history.add_user_message( + f"Hey! I am human! Index: {2*i+1}" + ), + args=[i], + ) + p.start() + proc.append(p) + + for p in proc: + p.join() + + messages = history.messages + + assert len(messages) == 2 * NUM_MESSAGES + + messages.sort(key=get_index_from_message) + + for i in range(2 * NUM_MESSAGES): + type = AIMessage if i % 2 == 0 else HumanMessage + content = ( + f"Hey! I am AI! Index: {i}" + if i % 2 == 0 + else f"Hey! I am human! Index: {i}" + ) + assert isinstance(messages[i], type) + assert messages[i].content == content + + history.clear() + assert len(history.messages) == 0 + + +def test_bigtable_multiple_sessions( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + session_id1 = uuid.uuid4().hex + history1 = BigtableChatMessageHistory( + instance_id, table_id, session_id1, client=client + ) + session_id2 = uuid.uuid4().hex + history2 = BigtableChatMessageHistory( + instance_id, table_id, session_id2, client=client + ) + + history1.add_ai_message("Hey! I am AI!") + history2.add_user_message("Hey! I am human!") + messages1 = history1.messages + messages2 = history2.messages + + assert len(messages1) == 1 + assert len(messages2) == 1 + assert isinstance(messages1[0], AIMessage) + assert messages1[0].content == "Hey! I am AI!" + assert isinstance(messages2[0], HumanMessage) + assert messages2[0].content == "Hey! I am human!" + + history1.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 1 + + history2.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 0 + + +def test_bigtable_missing_instance( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + non_existent_instance_id = "non-existent" + with pytest.raises(NameError) as excinfo: + BigtableChatMessageHistory( + non_existent_instance_id, table_id, "", client=client + ) + + assert str(excinfo.value) == f"Instance {non_existent_instance_id} does not exist" + + +def test_bigtable_missing_table( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + non_existent_table_id = "non_existent" + with pytest.raises(NameError) as excinfo: + BigtableChatMessageHistory( + instance_id, non_existent_table_id, "", client=client + ) + assert ( + str(excinfo.value) + == f"Table {non_existent_table_id} does not exist on instance {instance_id}" + ) + + +def test_bigtable_missing_column_family( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + other_table_id = table_id + "1" + client.instance(instance_id).table(other_table_id).create() + + with pytest.raises(NameError) as excinfo: + BigtableChatMessageHistory(instance_id, other_table_id, "", client=client) + assert ( + str(excinfo.value) + == f"Column family langchain does not exist on table {other_table_id}" + ) + + client.instance(instance_id).table(other_table_id).delete() + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v