diff --git a/langchain_postgres/chat_message_histories.py b/langchain_postgres/chat_message_histories.py index c1ff7cc..95a84ff 100644 --- a/langchain_postgres/chat_message_histories.py +++ b/langchain_postgres/chat_message_histories.py @@ -18,60 +18,61 @@ logger = logging.getLogger(__name__) -def _create_table_and_index(table_name: str) -> List[sql.Composed]: +def _create_table_and_index(table_name: str, schema_name: str) -> List[sql.Composed]: """Make a SQL query to create a table.""" index_name = f"idx_{table_name}_session_id" statements = [ sql.SQL( """ - CREATE TABLE IF NOT EXISTS {table_name} ( + CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id SERIAL PRIMARY KEY, session_id UUID NOT NULL, message JSONB NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); """ - ).format(table_name=sql.Identifier(table_name)), + ).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name)), sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id); + CREATE INDEX IF NOT EXISTS {index_name} ON {schema_name}.{table_name} (session_id); """ ).format( - table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name) + schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name) ), ] return statements -def _get_messages_query(table_name: str) -> sql.Composed: +def _get_messages_query(table_name: str, schema_name: str) -> sql.Composed: """Make a SQL query to get messages for a given session.""" return sql.SQL( "SELECT message " - "FROM {table_name} " + "FROM {schema_name}.{table_name} " "WHERE session_id = %(session_id)s " "ORDER BY id;" - ).format(table_name=sql.Identifier(table_name)) + ).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name)) -def _delete_by_session_id_query(table_name: str) -> sql.Composed: +def _delete_by_session_id_query(table_name: str, schema_name: str) -> sql.Composed: """Make a SQL query to delete messages for a given session.""" return sql.SQL( - "DELETE FROM {table_name} WHERE session_id = %(session_id)s;" - ).format(table_name=sql.Identifier(table_name)) + "DELETE FROM {schema_name}.{table_name} WHERE session_id = %(session_id)s;" + ).format(schema_name=sql.Identifier(schema_name) ,table_name=sql.Identifier(table_name)) -def _delete_table_query(table_name: str) -> sql.Composed: +def _delete_table_query(table_name: str, schema_name: str) -> sql.Composed: """Make a SQL query to delete a table.""" - return sql.SQL("DROP TABLE IF EXISTS {table_name};").format( + return sql.SQL("DROP TABLE IF EXISTS {schema_name}.{table_name};").format( + schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name) ) -def _insert_message_query(table_name: str) -> sql.Composed: +def _insert_message_query(table_name: str, schema_name: str) -> sql.Composed: """Make a SQL query to insert a message.""" return sql.SQL( - "INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)" - ).format(table_name=sql.Identifier(table_name)) + "INSERT INTO {schema_name}.{table_name} (session_id, message) VALUES (%s, %s)" + ).format(schema_name=sql.Identifier(schema_name), table_name=sql.Identifier(table_name)) class PostgresChatMessageHistory(BaseChatMessageHistory): @@ -79,6 +80,7 @@ def __init__( self, table_name: str, session_id: str, + schema_name: str = "public", /, *, sync_connection: Optional[psycopg.Connection] = None, @@ -203,15 +205,17 @@ def __init__( "characters and underscores." ) self._table_name = table_name + self._schema_name = schema_name @staticmethod def create_tables( connection: psycopg.Connection, table_name: str, + schema_name: str = "public", /, ) -> None: """Create the table schema in the database and create relevant indexes.""" - queries = _create_table_and_index(table_name) + queries = _create_table_and_index(table_name, schema_name) logger.info("Creating schema for table %s", table_name) with connection.cursor() as cursor: for query in queries: @@ -220,10 +224,10 @@ def create_tables( @staticmethod async def acreate_tables( - connection: psycopg.AsyncConnection, table_name: str, / + connection: psycopg.AsyncConnection, table_name: str, schema_name: str = "public", / ) -> None: """Create the table schema in the database and create relevant indexes.""" - queries = _create_table_and_index(table_name) + queries = _create_table_and_index(table_name, self._schema_name) logger.info("Creating schema for table %s", table_name) async with connection.cursor() as cur: for query in queries: @@ -231,7 +235,7 @@ async def acreate_tables( await connection.commit() @staticmethod - def drop_table(connection: psycopg.Connection, table_name: str, /) -> None: + def drop_table(connection: psycopg.Connection, table_name: str, schema_name: str = "public", /) -> None: """Delete the table schema in the database. WARNING: @@ -243,7 +247,7 @@ def drop_table(connection: psycopg.Connection, table_name: str, /) -> None: table_name: The name of the table to create. """ - query = _delete_table_query(table_name) + query = _delete_table_query(table_name, self._schema_name) logger.info("Dropping table %s", table_name) with connection.cursor() as cursor: cursor.execute(query) @@ -251,7 +255,7 @@ def drop_table(connection: psycopg.Connection, table_name: str, /) -> None: @staticmethod async def adrop_table( - connection: psycopg.AsyncConnection, table_name: str, / + connection: psycopg.AsyncConnection, table_name: str, schema_name: str = "public", / ) -> None: """Delete the table schema in the database. @@ -263,7 +267,7 @@ async def adrop_table( connection: Async database connection. table_name: The name of the table to create. """ - query = _delete_table_query(table_name) + query = _delete_table_query(table_name, self._schema_name) logger.info("Dropping table %s", table_name) async with connection.cursor() as acur: @@ -283,7 +287,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None: for message in messages ] - query = _insert_message_query(self._table_name) + query = _insert_message_query(self._table_name, self._schema_name) with self._connection.cursor() as cursor: cursor.executemany(query, values) @@ -302,7 +306,7 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: for message in messages ] - query = _insert_message_query(self._table_name) + query = _insert_message_query(self._table_name, self._schema_name) async with self._aconnection.cursor() as cursor: await cursor.executemany(query, values) await self._aconnection.commit() @@ -315,7 +319,7 @@ def get_messages(self) -> List[BaseMessage]: "with a sync connection or use the async aget_messages method instead." ) - query = _get_messages_query(self._table_name) + query = _get_messages_query(self._table_name, self._schema_name) with self._connection.cursor() as cursor: cursor.execute(query, {"session_id": self._session_id}) @@ -332,7 +336,7 @@ async def aget_messages(self) -> List[BaseMessage]: "with an async connection or use the sync get_messages method instead." ) - query = _get_messages_query(self._table_name) + query = _get_messages_query(self._table_name, self._schema_name) async with self._aconnection.cursor() as cursor: await cursor.execute(query, {"session_id": self._session_id}) items = [record[0] for record in await cursor.fetchall()] @@ -345,12 +349,6 @@ def messages(self) -> List[BaseMessage]: """The abstraction required a property.""" return self.get_messages() - @messages.setter - def messages(self, value: list[BaseMessage]) -> None: - """Clear the stored messages and appends a list of messages.""" - self.clear() - self.add_messages(value) - def clear(self) -> None: """Clear the chat message history for the GIVEN session.""" if self._connection is None: @@ -359,7 +357,7 @@ def clear(self) -> None: "with a sync connection or use the async clear method instead." ) - query = _delete_by_session_id_query(self._table_name) + query = _delete_by_session_id_query(self._table_name, self._schema_name) with self._connection.cursor() as cursor: cursor.execute(query, {"session_id": self._session_id}) self._connection.commit() @@ -372,7 +370,7 @@ async def aclear(self) -> None: "with an async connection or use the sync clear method instead." ) - query = _delete_by_session_id_query(self._table_name) + query = _delete_by_session_id_query(self._table_name, self._schema_name) async with self._aconnection.cursor() as cursor: await cursor.execute(query, {"session_id": self._session_id}) await self._aconnection.commit() diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 0fc41ea..e2efa0c 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -49,9 +49,8 @@ from langchain_postgres._utils import maximal_marginal_relevance -warnings.simplefilter("once", PendingDeprecationWarning) - +warnings.simplefilter("once", PendingDeprecationWarning) class DistanceStrategy(str, enum.Enum): """Enumerator of the Distance strategies.""" @@ -101,7 +100,7 @@ class DistanceStrategy(str, enum.Enum): ) -def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: +def _get_embedding_collection_store(vector_dimension: Optional[int] = None, table_name: str = "pg_embedding", schema_name: str = "public") -> Any: global _classes if _classes is not None: return _classes @@ -111,7 +110,8 @@ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> A class CollectionStore(Base): """Collection store.""" - __tablename__ = "langchain_pg_collection" + __tablename__ = f"{table_name}_collection" + __table_args__ = {"schema": schema_name} uuid = sqlalchemy.Column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 @@ -198,16 +198,25 @@ async def aget_or_create( class EmbeddingStore(Base): """Embedding store.""" - __tablename__ = "langchain_pg_embedding" + __tablename__ = table_name + __table_args__ = ( + {"schema": schema_name}, + sqlalchemy.Index( + f"{table_name}_ix_cmetadata_gin", + "cmetadata", + postgresql_using="gin", + postgresql_ops={"cmetadata": "jsonb_path_ops"}, + ), + ) id = sqlalchemy.Column( - sqlalchemy.String, primary_key=True + sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True ) collection_id = sqlalchemy.Column( UUID(as_uuid=True), sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", + f"{schema_name}.{CollectionStore.__tablename__}.uuid", ondelete="CASCADE", ), ) @@ -219,7 +228,7 @@ class EmbeddingStore(Base): __table_args__ = ( sqlalchemy.Index( - "ix_cmetadata_gin", + f"{table_name}_ix_cmetadata_gin", "cmetadata", postgresql_using="gin", postgresql_ops={"cmetadata": "jsonb_path_ops"}, @@ -272,6 +281,7 @@ class PGVector(VectorStore): Instantiate: .. code-block:: python + from langchain_postgres import PGVector from langchain_postgres.vectorstores import PGVector from langchain_openai import OpenAIEmbeddings @@ -282,6 +292,8 @@ class PGVector(VectorStore): vector_store = PGVector( embeddings=OpenAIEmbeddings(model="text-embedding-3-large"), collection_name=collection_name, + schema_name=schema_name, + table_name=table_name, connection=connection, use_jsonb=True, ) @@ -303,7 +315,8 @@ class PGVector(VectorStore): .. code-block:: python vector_store.delete(ids=["3"]) - + or + vector_store.delete(cmetadata={"key1": "value1"}) Search: .. code-block:: python @@ -345,6 +358,7 @@ class PGVector(VectorStore): # delete documents # await vector_store.adelete(ids=["3"]) + # await vector_store.adelete(cmetadata={"key1": "value1"}) # search # results = vector_store.asimilarity_search(query="thud",k=1) @@ -379,6 +393,8 @@ def __init__( *, connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, + table_name: str = "knowledge_pg_embedding", + schema_name: str = "public", collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, @@ -421,7 +437,9 @@ def __init__( self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length + self.table_name = table_name self.collection_name = collection_name + self.schema_name = schema_name self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy self.pre_delete_collection = pre_delete_collection @@ -432,12 +450,16 @@ def __init__( self._async_init = False warnings.warn( + "PGVector is being deprecated and will be removed in the future. " + "Please migrate to PGVectorStore. " + "Refer to the migration guide at [https://github.com/langchain-ai/langchain-postgres/blob/main/examples/migrate_pgvector_to_pgvectorstore.md] for details.", + PendingDeprecationWarning, - ) + ) if isinstance(connection, str): if async_mode: self._async_engine = create_async_engine( @@ -479,7 +501,7 @@ def __post_init__( self.create_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, self.table_name, self.schema_name ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -495,7 +517,7 @@ async def __apost_init__( self._async_init = True EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, self.table_name, self.schema_name ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -600,6 +622,7 @@ def delete( self, ids: Optional[List[str]] = None, collection_only: bool = False, + cmetadata: dict = {}, **kwargs: Any, ) -> None: """Delete vectors by ids or uuids. @@ -629,12 +652,48 @@ def delete( stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) session.execute(stmt) + elif cmetadata: + self.logger.debug( + "Trying to delete vectors by cmetadata (represented by the model " + "using the custom cmetadata field)" + ) + key = list(cmetadata.keys())[0] + val = list(cmetadata.values())[0] + stmt = delete(self.EmbeddingStore).where(self.EmbeddingStore.cmetadata[key].astext == val) + + if collection_only: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + session.execute(stmt) + elif ids is None and not cmetadata: + self.logger.debug( + "Trying to delete all vectors" + ) + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + session.execute(stmt) session.commit() async def adelete( self, ids: Optional[List[str]] = None, collection_only: bool = False, + cmetadata: dict = {}, **kwargs: Any, ) -> None: """Async delete vectors by ids or uuids. @@ -665,6 +724,41 @@ async def adelete( stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) await session.execute(stmt) + elif cmetadata: + self.logger.debug( + "Trying to delete vectors by cmetadata (represented by the model " + "using the custom cmetadata field)" + ) + key = list(cmetadata.keys())[0] + val = list(cmetadata.values())[0] + stmt = delete(self.EmbeddingStore).where(self.EmbeddingStore.cmetadata[key].astext == val) + + if collection_only: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + await session.execute(stmt) + elif ids is None and not cmetadata: + self.logger.debug( + "Trying to delete all vectors" + ) + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + await session.execute(stmt) await session.commit() def get_collection(self, session: Session) -> Any: