class IndexManager:
    """Create, list, and look up vector indexes on the embedding column.

    Every operation targets the ``embedding`` column of the
    ``langchain_pg_embedding`` table. The plain methods use the engine through
    synchronous context managers, the ``a``-prefixed coroutines through
    asynchronous ones, so the engine kind has to match the methods called.

    Args:
        connection: A database URL, or an existing ``Engine``/``AsyncEngine``.
        async_mode: When ``connection`` is a URL, selects whether an async
            engine is created. It is also forwarded to the ``PGVector``
            instances built by ``get_index``/``aget_index``.

    Raises:
        ValueError: If ``connection`` is neither a string nor an engine.
    """

    _TABLE_NAME = "langchain_pg_embedding"
    _COLUMN_NAME = "embedding"
    _SUPPORTED_INDEX_TYPES = ("hnsw", "ivfflat")

    # pgvector names its operator classes ``vector_<suffix>_ops`` with the
    # suffixes l2 / ip / cosine. Strategy values that are not already such a
    # suffix are translated here.
    # NOTE(review): assumes the inner-product strategy's value is "inner" --
    # the enum is defined outside this block; confirm.
    _OPS_SUFFIX_BY_STRATEGY = {"inner": "ip"}

    _LIST_INDEXES_SQL = (
        "SELECT * FROM pg_indexes "
        "WHERE tablename = :table_name AND indexdef ~ :column_pattern"
    )
    _GET_INDEX_SQL = (
        "SELECT * FROM pg_indexes "
        "WHERE tablename = :table_name AND indexname = :index_name"
    )

    def __init__(
        self,
        connection: Union[str, Engine, AsyncEngine],
        async_mode: bool = False,
    ) -> None:
        self.async_mode = async_mode
        if isinstance(connection, str):
            if async_mode:
                self._engine = create_async_engine(connection)
            else:
                self._engine = create_engine(connection)
        elif isinstance(connection, (Engine, AsyncEngine)):
            self._engine = connection
        else:
            raise ValueError("Invalid connection type")

    # ------------------------------------------------------------------
    # Pure helpers (no database access)
    # ------------------------------------------------------------------
    @staticmethod
    def _checked_identifier(value: str, what: str) -> str:
        """Return ``value`` if it is a plain SQL identifier, else raise."""
        import re

        candidate = str(value)
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", candidate):
            raise ValueError(f"Invalid {what}: {value!r}")
        return candidate

    @classmethod
    def _normalize_index_type(cls, index_type: str) -> str:
        """Lower-case ``index_type`` and check it against the supported set."""
        normalized = str(index_type).strip().lower()
        if normalized not in cls._SUPPORTED_INDEX_TYPES:
            raise ValueError(
                f"Unsupported index type {index_type!r}; "
                f"expected one of {cls._SUPPORTED_INDEX_TYPES}"
            )
        return normalized

    @classmethod
    def _index_name(cls, index_type: str, strategy_value: str) -> str:
        """Return the index name derived from type and distance strategy."""
        kind = cls._normalize_index_type(index_type)
        strategy = cls._checked_identifier(strategy_value, "distance strategy")
        return f"{kind}_{strategy}_index"

    @classmethod
    def _render_index_params(cls, params: Dict[str, Any]) -> str:
        """Render storage parameters as ``key = value, ...``.

        DDL cannot take bind parameters, so keys and values are validated
        against a conservative character set before being interpolated.
        """
        import re

        rendered = []
        for key, value in params.items():
            name = cls._checked_identifier(key, "index parameter name")
            literal = str(value)
            if not re.fullmatch(r"[A-Za-z0-9_.+-]+", literal):
                raise ValueError(f"Invalid value for index parameter {key!r}: {value!r}")
            rendered.append(f"{name} = {literal}")
        return ", ".join(rendered)

    @classmethod
    def _create_index_sql(
        cls, index_type: str, strategy_value: str, params: Dict[str, Any]
    ) -> str:
        """Build the ``CREATE INDEX`` statement for the embedding column."""
        kind = cls._normalize_index_type(index_type)
        strategy = cls._checked_identifier(strategy_value, "distance strategy")
        ops_suffix = cls._OPS_SUFFIX_BY_STRATEGY.get(strategy, strategy)
        statement = (
            f"CREATE INDEX IF NOT EXISTS {kind}_{strategy}_index "
            f"ON {cls._TABLE_NAME} USING {kind} "
            f"({cls._COLUMN_NAME} vector_{ops_suffix}_ops)"
        )
        options = cls._render_index_params(params)
        if options:
            # An empty ``WITH ()`` clause is a syntax error, so it is only
            # emitted when at least one parameter was supplied.
            statement += f" WITH ({options})"
        return statement

    @classmethod
    def _strategy_value_from_indexdef(cls, indexdef: str) -> str:
        """Extract the distance-strategy value from an index definition.

        The operator class (``vector_<suffix>_ops``) is searched for anywhere
        in the definition, because a trailing ``WITH (...)`` clause means it
        is not necessarily the last token.
        """
        import re

        match = re.search(r"\bvector_([A-Za-z0-9]+)_ops\b", indexdef)
        if match is None:
            raise ValueError(
                f"No vector operator class found in index definition: {indexdef!r}"
            )
        suffix = match.group(1)
        for strategy_value, ops_suffix in cls._OPS_SUFFIX_BY_STRATEGY.items():
            if ops_suffix == suffix:
                return strategy_value
        return suffix

    @classmethod
    def _list_indexes_params(cls) -> Dict[str, Any]:
        """Bind parameters for ``_LIST_INDEXES_SQL``.

        The pattern matches the column inside the index column list. A plain
        ``LIKE '%embedding%'`` would match every index on the table, since the
        table name itself contains "embedding".
        """
        return {
            "table_name": cls._TABLE_NAME,
            "column_pattern": f"[(,] ?{cls._COLUMN_NAME}[ ,)]",
        }

    def _vectorstore_from_index_row(
        self, row: Any, embeddings: Embeddings, collection_name: str
    ) -> "Optional[PGVector]":
        """Build a ``PGVector`` from a ``pg_indexes`` row mapping, if any."""
        if row is None:
            return None
        strategy = DistanceStrategy(
            self._strategy_value_from_indexdef(row["indexdef"])
        )
        return PGVector(
            embeddings=embeddings,
            connection=self._engine,
            collection_name=collection_name,
            distance_strategy=strategy,
            async_mode=self.async_mode,
        )

    # ------------------------------------------------------------------
    # Listing
    # ------------------------------------------------------------------
    def list_indexes(self) -> List[Dict[str, Any]]:
        """List the indexes that cover the embedding column.

        Returns:
            One dict per matching ``pg_indexes`` row.
        """
        with self._engine.connect() as conn:
            result = conn.execute(
                text(self._LIST_INDEXES_SQL), self._list_indexes_params()
            )
            # Rows are tuple-like; ``mappings()`` yields dict-convertible rows.
            return [dict(row) for row in result.mappings()]

    async def alist_indexes(self) -> List[Dict[str, Any]]:
        """Asynchronously list the indexes that cover the embedding column.

        Returns:
            One dict per matching ``pg_indexes`` row.
        """
        async with self._engine.connect() as conn:
            result = await conn.execute(
                text(self._LIST_INDEXES_SQL), self._list_indexes_params()
            )
            return [dict(row) for row in result.mappings()]

    # ------------------------------------------------------------------
    # Creation
    # ------------------------------------------------------------------
    def create_index(
        self, index_type: str, distance_strategy: DistanceStrategy, **kwargs: Any
    ) -> str:
        """Create an index (HNSW or IVFFlat) on the embedding column.

        The statement uses ``IF NOT EXISTS``, so calling this again for the
        same type and strategy is a no-op.

        Args:
            index_type: The type of index to create ('hnsw' or 'ivfflat').
            distance_strategy: Distance strategy; its ``value`` selects the
                operator class and is part of the index name.
            kwargs: Storage parameters for the index (e.g. m,
                ef_construction, lists).

        Returns:
            The name of the index.

        Raises:
            ValueError: If the index type is unsupported or a parameter name
                or value contains characters that are unsafe to interpolate.
        """
        strategy_value = distance_strategy.value
        statement = self._create_index_sql(index_type, strategy_value, kwargs)
        # ``begin()`` commits on success; with a bare ``connect()`` the DDL
        # would be rolled back when the connection is released.
        with self._engine.begin() as conn:
            conn.execute(text(statement))
        return self._index_name(index_type, strategy_value)

    async def acreate_index(
        self, index_type: str, distance_strategy: DistanceStrategy, **kwargs: Any
    ) -> str:
        """Asynchronously create an index (HNSW or IVFFlat).

        Same contract as ``create_index``.

        Args:
            index_type: The type of index to create ('hnsw' or 'ivfflat').
            distance_strategy: Distance strategy; its ``value`` selects the
                operator class and is part of the index name.
            kwargs: Storage parameters for the index (e.g. m,
                ef_construction, lists).

        Returns:
            The name of the index.

        Raises:
            ValueError: If the index type is unsupported or a parameter name
                or value contains characters that are unsafe to interpolate.
        """
        strategy_value = distance_strategy.value
        statement = self._create_index_sql(index_type, strategy_value, kwargs)
        async with self._engine.begin() as conn:
            await conn.execute(text(statement))
        return self._index_name(index_type, strategy_value)

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------
    def get_index(
        self, index_name: str, embeddings: Embeddings, collection_name: str
    ) -> "Optional[PGVector]":
        """Look up an index and return a ``PGVector`` configured to match it.

        Args:
            index_name: Name of the index to look up.
            embeddings: Embedding function handed to the returned store.
            collection_name: Collection the returned store operates on.

        Returns:
            A ``PGVector`` using the index's distance strategy, or ``None``
            when no such index exists on the embedding table.

        Raises:
            ValueError: If the index exists but its definition contains no
                vector operator class.
        """
        with self._engine.connect() as conn:
            row = (
                conn.execute(
                    text(self._GET_INDEX_SQL),
                    {"table_name": self._TABLE_NAME, "index_name": index_name},
                )
                .mappings()
                .first()
            )
        return self._vectorstore_from_index_row(row, embeddings, collection_name)

    async def aget_index(
        self, index_name: str, embeddings: Embeddings, collection_name: str
    ) -> "Optional[PGVector]":
        """Asynchronously look up an index and return a matching ``PGVector``.

        Same contract as ``get_index``.
        """
        async with self._engine.connect() as conn:
            result = await conn.execute(
                text(self._GET_INDEX_SQL),
                {"table_name": self._TABLE_NAME, "index_name": index_name},
            )
            row = result.mappings().first()
        return self._vectorstore_from_index_row(row, embeddings, collection_name)
def _indexed_store_kwargs(
    collection_name: str,
    index_type: str,
    index_params: dict,
    strategy_value: str,
) -> dict:
    """Return the keyword arguments shared by the index construction tests.

    The distance strategy is looked up by value rather than by member name:
    the index names asserted below (e.g. ``hnsw_l2_index``) already assume
    that a strategy with value ``"l2"`` exists, whereas the member's name is
    not something these tests depend on.
    """
    return dict(
        texts=["foo", "bar", "baz"],
        collection_name=collection_name,
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type=index_type,
        index_params=index_params,
        distance_strategy=DistanceStrategy(strategy_value),
    )


def _hnsw_store_kwargs() -> dict:
    """Arguments for an HNSW / L2 indexed store."""
    return _indexed_store_kwargs(
        "test_collection_hnsw", "hnsw", {"m": 16, "ef_construction": 64}, "l2"
    )


def _ivfflat_store_kwargs() -> dict:
    """Arguments for an IVFFlat / cosine indexed store."""
    return _indexed_store_kwargs(
        "test_collection_ivfflat", "ivfflat", {"lists": 100}, "cosine"
    )


def test_pgvector_with_hnsw_index() -> None:
    """Test end to end construction and search with HNSW index."""
    docsearch = PGVector.from_texts(**_hnsw_store_kwargs())
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]


@pytest.mark.asyncio
async def test_async_pgvector_with_hnsw_index() -> None:
    """Test async end to end construction and search with HNSW index."""
    docsearch = await PGVector.afrom_texts(**_hnsw_store_kwargs())
    output = await docsearch.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]


def test_pgvector_with_ivfflat_index() -> None:
    """Test end to end construction and search with IVFFlat index."""
    docsearch = PGVector.from_texts(**_ivfflat_store_kwargs())
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]


@pytest.mark.asyncio
async def test_async_pgvector_with_ivfflat_index() -> None:
    """Test async end to end construction and search with IVFFlat index."""
    docsearch = await PGVector.afrom_texts(**_ivfflat_store_kwargs())
    output = await docsearch.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]


def test_get_index() -> None:
    """Test retrieving a VectorStore instance from an existing index."""
    # Build the collection and its index here so the test does not depend on
    # another test having run first.
    PGVector.from_texts(**_hnsw_store_kwargs())

    index_manager = IndexManager(connection=CONNECTION_STRING)
    vectorstore = index_manager.get_index(
        "hnsw_l2_index", FakeEmbeddingsWithAdaDimension(), "test_collection_hnsw"
    )
    assert vectorstore is not None
    output = vectorstore.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]


@pytest.mark.asyncio
async def test_async_get_index() -> None:
    """Test asynchronously retrieving a VectorStore instance."""
    # Build the collection and its index here so the test does not depend on
    # another test having run first.
    await PGVector.afrom_texts(**_hnsw_store_kwargs())

    index_manager = IndexManager(connection=CONNECTION_STRING, async_mode=True)
    vectorstore = await index_manager.aget_index(
        "hnsw_l2_index", FakeEmbeddingsWithAdaDimension(), "test_collection_hnsw"
    )
    assert vectorstore is not None
    output = await vectorstore.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]