Skip to content

refactor: vectorstores & chat_message_histories #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 33 additions & 35 deletions langchain_postgres/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,67 +18,69 @@
logger = logging.getLogger(__name__)


def _create_table_and_index(table_name: str) -> List[sql.Composed]:
def _create_table_and_index(table_name: str, schema_name: str) -> List[sql.Composed]:
"""Make a SQL query to create a table."""
index_name = f"idx_{table_name}_session_id"
statements = [
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {table_name} (
CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (
id SERIAL PRIMARY KEY,
session_id UUID NOT NULL,
message JSONB NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
).format(table_name=sql.Identifier(table_name)),
).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name)),
sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
CREATE INDEX IF NOT EXISTS {index_name} ON {schema_name}.{table_name} (session_id);
"""
).format(
table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name)
schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name)
),
]
return statements


def _get_messages_query(table_name: str) -> sql.Composed:
def _get_messages_query(table_name: str, schema_name: str) -> sql.Composed:
"""Make a SQL query to get messages for a given session."""
return sql.SQL(
"SELECT message "
"FROM {table_name} "
"FROM {schema_name}.{table_name} "
"WHERE session_id = %(session_id)s "
"ORDER BY id;"
).format(table_name=sql.Identifier(table_name))
).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name))


def _delete_by_session_id_query(table_name: str) -> sql.Composed:
def _delete_by_session_id_query(table_name: str, schema_name: str) -> sql.Composed:
"""Make a SQL query to delete messages for a given session."""
return sql.SQL(
"DELETE FROM {table_name} WHERE session_id = %(session_id)s;"
).format(table_name=sql.Identifier(table_name))
"DELETE FROM {schema_name}.{table_name} WHERE session_id = %(session_id)s;"
).format(schema_name=sql.Identifier(schema_name) ,table_name=sql.Identifier(table_name))


def _delete_table_query(table_name: str) -> sql.Composed:
def _delete_table_query(table_name: str, schema_name: str) -> sql.Composed:
"""Make a SQL query to delete a table."""
return sql.SQL("DROP TABLE IF EXISTS {table_name};").format(
return sql.SQL("DROP TABLE IF EXISTS {schema_name}.{table_name};").format(
schema_name=sql.Identifier(schema_name),
table_name=sql.Identifier(table_name)
)


def _insert_message_query(table_name: str) -> sql.Composed:
def _insert_message_query(table_name: str, schema_name: str) -> sql.Composed:
"""Make a SQL query to insert a message."""
return sql.SQL(
"INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)"
).format(table_name=sql.Identifier(table_name))
"INSERT INTO {schema_name}.{table_name} (session_id, message) VALUES (%s, %s)"
).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name))


class PostgresChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
table_name: str,
session_id: str,
schema_name: str = "public",
/,
*,
sync_connection: Optional[psycopg.Connection] = None,
Expand Down Expand Up @@ -203,15 +205,17 @@ def __init__(
"characters and underscores."
)
self._table_name = table_name
self._schema_name = schema_name

@staticmethod
def create_tables(
connection: psycopg.Connection,
table_name: str,
schema_name: str = "public",
/,
) -> None:
"""Create the table schema in the database and create relevant indexes."""
queries = _create_table_and_index(table_name)
queries = _create_table_and_index(table_name, schema_name)
logger.info("Creating schema for table %s", table_name)
with connection.cursor() as cursor:
for query in queries:
Expand All @@ -220,18 +224,18 @@ def create_tables(

@staticmethod
async def acreate_tables(
connection: psycopg.AsyncConnection, table_name: str, /
connection: psycopg.AsyncConnection, table_name: str, schema_name: str = "public", /
) -> None:
"""Create the table schema in the database and create relevant indexes."""
queries = _create_table_and_index(table_name)
queries = _create_table_and_index(table_name, self._schema_name)
logger.info("Creating schema for table %s", table_name)
async with connection.cursor() as cur:
for query in queries:
await cur.execute(query)
await connection.commit()

@staticmethod
def drop_table(connection: psycopg.Connection, table_name: str, /) -> None:
def drop_table(connection: psycopg.Connection, table_name: str, schema_name: str = "public", /) -> None:
"""Delete the table schema in the database.

WARNING:
Expand All @@ -243,15 +247,15 @@ def drop_table(connection: psycopg.Connection, table_name: str, /) -> None:
table_name: The name of the table to create.
"""

query = _delete_table_query(table_name)
query = _delete_table_query(table_name, self._schema_name)
logger.info("Dropping table %s", table_name)
with connection.cursor() as cursor:
cursor.execute(query)
connection.commit()

@staticmethod
async def adrop_table(
connection: psycopg.AsyncConnection, table_name: str, /
connection: psycopg.AsyncConnection, table_name: str, schema_name: str = "public", /
) -> None:
"""Delete the table schema in the database.

Expand All @@ -263,7 +267,7 @@ async def adrop_table(
connection: Async database connection.
table_name: The name of the table to create.
"""
query = _delete_table_query(table_name)
query = _delete_table_query(table_name, self._schema_name)
logger.info("Dropping table %s", table_name)

async with connection.cursor() as acur:
Expand All @@ -283,7 +287,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
for message in messages
]

query = _insert_message_query(self._table_name)
query = _insert_message_query(self._table_name, self._schema_name)

with self._connection.cursor() as cursor:
cursor.executemany(query, values)
Expand All @@ -302,7 +306,7 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
for message in messages
]

query = _insert_message_query(self._table_name)
query = _insert_message_query(self._table_name, self._schema_name)
async with self._aconnection.cursor() as cursor:
await cursor.executemany(query, values)
await self._aconnection.commit()
Expand All @@ -315,7 +319,7 @@ def get_messages(self) -> List[BaseMessage]:
"with a sync connection or use the async aget_messages method instead."
)

query = _get_messages_query(self._table_name)
query = _get_messages_query(self._table_name, self._schema_name)

with self._connection.cursor() as cursor:
cursor.execute(query, {"session_id": self._session_id})
Expand All @@ -332,7 +336,7 @@ async def aget_messages(self) -> List[BaseMessage]:
"with an async connection or use the sync get_messages method instead."
)

query = _get_messages_query(self._table_name)
query = _get_messages_query(self._table_name, self._schema_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]
Expand All @@ -345,12 +349,6 @@ def messages(self) -> List[BaseMessage]:
"""The abstraction required a property."""
return self.get_messages()

@messages.setter
def messages(self, value: list[BaseMessage]) -> None:
"""Clear the stored messages and appends a list of messages."""
self.clear()
self.add_messages(value)

def clear(self) -> None:
"""Clear the chat message history for the GIVEN session."""
if self._connection is None:
Expand All @@ -359,7 +357,7 @@ def clear(self) -> None:
"with a sync connection or use the async clear method instead."
)

query = _delete_by_session_id_query(self._table_name)
query = _delete_by_session_id_query(self._table_name, self._schema_name)
with self._connection.cursor() as cursor:
cursor.execute(query, {"session_id": self._session_id})
self._connection.commit()
Expand All @@ -372,7 +370,7 @@ async def aclear(self) -> None:
"with an async connection or use the sync clear method instead."
)

query = _delete_by_session_id_query(self._table_name)
query = _delete_by_session_id_query(self._table_name, self._schema_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
await self._aconnection.commit()
Loading