diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 56caca04381cf4..d4c2372c692a52 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -14,6 +14,7 @@ chardet>=5.1.0,<6 cloudpathlib>=0.18,<0.19 cloudpickle>=2.0.0 cohere>=4,<6 +crate==1.0.0dev1 databricks-vectorsearch>=0.21,<0.22 datasets>=2.15.0,<3 dgml-utils>=0.3.0,<0.4 @@ -76,6 +77,7 @@ requests-toolbelt>=1.0.0,<2 rspace_client>=2.5.0,<3 scikit-learn>=1.2.2,<2 simsimd>=5.0.0,<6 +sqlalchemy-cratedb>=0.40.0,<1 sqlite-vss>=0.1.2,<0.2 sqlite-vec>=0.1.0,<0.2 sseclient-py>=1.8.0,<2 diff --git a/libs/community/langchain_community/chat_message_histories/__init__.py b/libs/community/langchain_community/chat_message_histories/__init__.py index fc20cacacceab5..4f631a57c81c56 100644 --- a/libs/community/langchain_community/chat_message_histories/__init__.py +++ b/libs/community/langchain_community/chat_message_histories/__init__.py @@ -28,6 +28,9 @@ from langchain_community.chat_message_histories.cosmos_db import ( CosmosDBChatMessageHistory, ) + from langchain_community.chat_message_histories.cratedb import ( + CrateDBChatMessageHistory, + ) from langchain_community.chat_message_histories.dynamodb import ( DynamoDBChatMessageHistory, ) @@ -94,6 +97,7 @@ "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory", @@ -120,6 +124,7 @@ "CassandraChatMessageHistory": "langchain_community.chat_message_histories.cassandra", # noqa: E501 "ChatMessageHistory": "langchain_community.chat_message_histories.in_memory", "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories.cosmos_db", # noqa: E501 + "CrateDBChatMessageHistory": "langchain_community.chat_message_histories.cratedb", # noqa: E501 "DynamoDBChatMessageHistory": "langchain_community.chat_message_histories.dynamodb", "ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories.elasticsearch", # noqa: E501 "FileChatMessageHistory": "langchain_community.chat_message_histories.file", diff --git a/libs/community/langchain_community/chat_message_histories/cratedb.py b/libs/community/langchain_community/chat_message_histories/cratedb.py new file mode 100644 index 00000000000000..fb146956d2e0c7 --- /dev/null +++ b/libs/community/langchain_community/chat_message_histories/cratedb.py @@ -0,0 +1,109 @@ +import json +import typing as t + +import sqlalchemy as sa +from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict + +from langchain_community.chat_message_histories.sql import ( + BaseMessageConverter, + SQLChatMessageHistory, +) + + +def create_message_model(table_name, DynamicBase): # type: ignore + """ + Create a message model for a given table name. + + This is a specialized version for CrateDB for generating integer-based primary keys. + TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant + returning its integer value. + + Args: + table_name: The name of the table to use. + DynamicBase: The base class to use for the model. + + Returns: + The model class. + """ + + # Model is declared inside a function to be able to use a dynamic table name. + class Message(DynamicBase): + __tablename__ = table_name + id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now()) + session_id = sa.Column(sa.Text) + message = sa.Column(sa.Text) + + return Message + + +class CrateDBMessageConverter(BaseMessageConverter): + """ + The default message converter for CrateDBMessageConverter. + + It is the same as the generic `SQLChatMessageHistory` converter, + but swaps in a different `create_message_model` function. + """ + + def __init__(self, table_name: str): + self.model_class = create_message_model(table_name, sa.orm.declarative_base()) + + def from_sql_model(self, sql_message: t.Any) -> BaseMessage: + return messages_from_dict([json.loads(sql_message.message)])[0] + + def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any: + return self.model_class( + session_id=session_id, message=json.dumps(_message_to_dict(message)) + ) + + def get_sql_model_class(self) -> t.Any: + return self.model_class + + +class CrateDBChatMessageHistory(SQLChatMessageHistory): + """ + It is the same as the generic `SQLChatMessageHistory` implementation, + but swaps in a different message converter by default. + """ + + DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter + + def __init__( + self, + session_id: str, + connection_string: str, + table_name: str = "message_store", + session_id_field_name: str = "session_id", + custom_message_converter: t.Optional[BaseMessageConverter] = None, + ): + from sqlalchemy_cratedb.support import refresh_after_dml + + super().__init__( + session_id, + connection_string, + table_name=table_name, + session_id_field_name=session_id_field_name, + custom_message_converter=custom_message_converter, + ) + + # Patch dialect to invoke `REFRESH TABLE` after each DML operation. + refresh_after_dml(self.Session) + + def _messages_query(self) -> sa.sql.Select: + """ + Construct an SQLAlchemy selectable to query for messages. + For CrateDB, add an `ORDER BY` clause on the primary key. + """ + selectable = super()._messages_query() + selectable = selectable.order_by(self.sql_model_class.id) + return selectable + + def clear(self) -> None: + """ + Needed for CrateDB to synchronize data because `on_flush` does not catch it. + """ + from sqlalchemy_cratedb.support import refresh_table + + outcome = super().clear() + with self.Session() as session: + refresh_table(session, self.sql_model_class) + return outcome diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index 2c3b2351c471d0..9d7ea7faa7a78f 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -10,12 +10,14 @@ List, Optional, Sequence, + Type, Union, cast, ) from langchain_core._api import deprecated, warn_deprecated -from sqlalchemy import Column, Integer, Text, delete, select +from sqlalchemy import Column, Integer, Text, create_engine, delete, select +from sqlalchemy.sql import Select try: from sqlalchemy.orm import declarative_base @@ -27,7 +29,6 @@ message_to_dict, messages_from_dict, ) -from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -38,7 +39,6 @@ Session as SQLSession, ) from sqlalchemy.orm import ( - declarative_base, scoped_session, sessionmaker, ) @@ -55,6 +55,10 @@ class BaseMessageConverter(ABC): """Convert BaseMessage to the SQLAlchemy model.""" + @abstractmethod + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + @abstractmethod def from_sql_model(self, sql_message: Any) -> BaseMessage: """Convert a SQLAlchemy model to a BaseMessage instance.""" @@ -146,6 +150,8 @@ class SQLChatMessageHistory(BaseChatMessageHistory): """ + DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter + @property @deprecated("0.2.2", removal="1.0", alternative="session_maker") def Session(self) -> Union[scoped_session, async_sessionmaker]: @@ -220,7 +226,9 @@ def __init__( self.session_maker = scoped_session(sessionmaker(bind=self.engine)) self.session_id_field_name = session_id_field_name - self.converter = custom_message_converter or DefaultMessageConverter(table_name) + self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER( + table_name + ) self.sql_model_class = self.converter.get_sql_model_class() if not hasattr(self.sql_model_class, session_id_field_name): raise ValueError("SQL model class must have session_id column") @@ -241,6 +249,17 @@ async def _acreate_table_if_not_exists(self) -> None: await conn.run_sync(self.sql_model_class.metadata.create_all) self._table_created = True + def _messages_query(self) -> Select: + """Construct an SQLAlchemy selectable to query for messages""" + return ( + select(self.sql_model_class) + .where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + .order_by(self.sql_model_class.id.asc()) + ) + @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" diff --git a/libs/community/tests/integration_tests/memory/test_cratedb.py b/libs/community/tests/integration_tests/memory/test_cratedb.py new file mode 100644 index 00000000000000..d84f34d32f4bad --- /dev/null +++ b/libs/community/tests/integration_tests/memory/test_cratedb.py @@ -0,0 +1,169 @@ +import json +import os +from typing import Any, Generator, Tuple + +import pytest +import sqlalchemy as sa +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import CrateDBChatMessageHistory +from langchain.memory.chat_message_histories.sql import DefaultMessageConverter +from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict +from sqlalchemy import Column, Integer, Text +from sqlalchemy.orm import DeclarativeBase + + +@pytest.fixture() +def connection_string() -> str: + return os.environ.get( + "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" + ) + + +@pytest.fixture() +def engine(connection_string: str) -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(connection_string, echo=True) + + +@pytest.fixture(autouse=True) +def reset_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS test_table;")) + connection.commit() + + +@pytest.fixture() +def sql_histories( + connection_string: str, +) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]: + """ + Provide the test cases with data fixtures. + """ + message_history = CrateDBChatMessageHistory( + session_id="123", connection_string=connection_string, table_name="test_table" + ) + # Create history for other session + other_history = CrateDBChatMessageHistory( + session_id="456", connection_string=connection_string, table_name="test_table" + ) + + yield message_history, other_history + message_history.clear() + other_history.clear() + + +def test_add_messages( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], +) -> None: + history1, _ = sql_histories + history1.add_user_message("Hello!") + history1.add_ai_message("Hi there!") + + messages = history1.messages + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + + +def test_multiple_sessions( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], +) -> None: + history1, history2 = sql_histories + + # first session + history1.add_user_message("Hello!") + history1.add_ai_message("Hi there!") + history1.add_user_message("Whats cracking?") + + # second session + history2.add_user_message("Hellox") + + messages1 = history1.messages + messages2 = history2.messages + + # Ensure the messages are added correctly in the first session + assert len(messages1) == 3, "waat" + assert messages1[0].content == "Hello!" + assert messages1[1].content == "Hi there!" + assert messages1[2].content == "Whats cracking?" + + assert len(messages2) == 1 + assert len(messages1) == 3 + assert messages2[0].content == "Hellox" + assert messages1[0].content == "Hello!" + assert messages1[1].content == "Hi there!" + assert messages1[2].content == "Whats cracking?" + + +def test_clear_messages( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + assert len(sql_history.messages) == 2 + # Now create another history with different session id + other_history.add_user_message("Hellox") + assert len(other_history.messages) == 1 + assert len(sql_history.messages) == 2 + # Now clear the first history + sql_history.clear() + assert len(sql_history.messages) == 0 + assert len(other_history.messages) == 1 + + +def test_model_no_session_id_field_error(connection_string: str) -> None: + class Base(DeclarativeBase): + pass + + class Model(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + test_field = Column(Text) + + class CustomMessageConverter(DefaultMessageConverter): + def get_sql_model_class(self) -> Any: + return Model + + with pytest.raises(ValueError): + CrateDBChatMessageHistory( + "test", + connection_string, + custom_message_converter=CustomMessageConverter("test_table"), + ) + + +def test_memory_with_message_store(connection_string: str) -> None: + """ + Test ConversationBufferMemory with a message store. + """ + # Setup CrateDB as a message store. + message_history = CrateDBChatMessageHistory( + connection_string=connection_string, session_id="test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # Add a few messages. + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # Get the message history from the memory store and turn it into JSON. + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + # Verify the outcome. + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Clear the conversation history, and verify that. + memory.chat_memory.clear() + assert memory.chat_memory.messages == [] diff --git a/libs/community/tests/unit_tests/chat_message_histories/test_imports.py b/libs/community/tests/unit_tests/chat_message_histories/test_imports.py index 4c14a0efd9a56c..b9e2c06b715671 100644 --- a/libs/community/tests/unit_tests/chat_message_histories/test_imports.py +++ b/libs/community/tests/unit_tests/chat_message_histories/test_imports.py @@ -5,6 +5,7 @@ "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory",