diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py
index 95023f1..d272062 100644
--- a/langchain_postgres/vectorstores.py
+++ b/langchain_postgres/vectorstores.py
@@ -59,9 +59,25 @@ class DistanceStrategy(str, enum.Enum):
 
 
 DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
 
+
+class HNSWDistanceStrategy(str, enum.Enum):
+    """Enumerator of the HNSW index distance strategies (pgvector opclasses)."""
+
+    EUCLIDEAN = "vector_l2_ops"
+    COSINE = "vector_cosine_ops"
+    MAX_INNER_PRODUCT = "vector_ip_ops"
+    L1_DISTANCE = "vector_l1_ops"
+    HAMMING_DISTANCE = "bit_hamming_ops"
+    # Renamed from Jaccard_DISTANCE: constants use UPPER_SNAKE_CASE,
+    # matching the sibling members above.
+    JACCARD_DISTANCE = "bit_jaccard_ops"
+
+
+DEFAULT_HNSW_DISTANCE_STRATEGY = HNSWDistanceStrategy.COSINE
+
+
 Base = declarative_base()  # type: Any
 
+ADA_TOKEN_COUNT = 1536
 
 _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
@@ -887,6 +903,11 @@ def similarity_search(
             query (str): Query text to search for.
             k (int): Number of results to return. Defaults to 4.
             filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+            **kwargs (Any): Additional parameters.
+                - use_hnsw (bool): Use HNSW index for similarity search.
+                    Defaults to False.
+                - ef_search (int): The number of candidates to consider
+                    during the search. Defaults to 100.
 
         Returns:
             List of Documents most similar to the query.
@@ -897,6 +918,7 @@ def similarity_search(
             embedding=embedding,
             k=k,
             filter=filter,
+            **kwargs,
         )
 
     async def asimilarity_search(
@@ -912,6 +934,11 @@ async def asimilarity_search(
             query (str): Query text to search for.
             k (int): Number of results to return. Defaults to 4.
             filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+            **kwargs (Any): Additional parameters.
+                - use_hnsw (bool): Use HNSW index for similarity search.
+                    Defaults to False.
+                - ef_search (int): The number of candidates to consider
+                    during the search. Defaults to 100.
 
         Returns:
             List of Documents most similar to the query.
@@ -922,6 +949,7 @@ async def asimilarity_search( embedding=embedding, k=k, filter=filter, + **kwargs, ) def similarity_search_with_score( @@ -929,6 +957,7 @@ def similarity_search_with_score( query: str, k: int = 4, filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -936,6 +965,11 @@ def similarity_search_with_score( query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List of Documents most similar to the query and score for each. @@ -943,7 +977,7 @@ def similarity_search_with_score( assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter + embedding=embedding, k=k, filter=filter, **kwargs ) return docs @@ -952,6 +986,7 @@ async def asimilarity_search_with_score( query: str, k: int = 4, filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -959,6 +994,11 @@ async def asimilarity_search_with_score( query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List of Documents most similar to the query and score for each. 
@@ -966,7 +1006,7 @@ async def asimilarity_search_with_score(
         await self.__apost_init__()  # Lazy async init
         embedding = self.embedding_function.embed_query(query)
         docs = await self.asimilarity_search_with_score_by_vector(
-            embedding=embedding, k=k, filter=filter
+            embedding=embedding, k=k, filter=filter, **kwargs
         )
         return docs
 
@@ -989,9 +1029,15 @@ def similarity_search_with_score_by_vector(
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         assert not self._async_engine, "This method must be called without async_mode"
-        results = self.__query_collection(embedding=embedding, k=k, filter=filter)
+        results = self.__query_collection(
+            embedding=embedding,
+            k=k,
+            filter=filter,
+            **kwargs,
+        )
         return self._results_to_docs_and_scores(results)
 
@@ -1000,11 +1046,16 @@ async def asimilarity_search_with_score_by_vector(
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         await self.__apost_init__()  # Lazy async init
         async with self._make_async_session() as session:  # type: ignore[arg-type]
             results = await self.__aquery_collection(
-                session=session, embedding=embedding, k=k, filter=filter
+                session=session,
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                **kwargs,
             )
             return self._results_to_docs_and_scores(results)
 
@@ -1343,11 +1394,30 @@ def _create_filter_clause(self, filters: Any) -> Any:
                 f"Invalid type: Expected a dictionary but got type: {type(filters)}"
             )
 
+    def _execute_hnsw_settings(self, session: Session, **kwargs: Any) -> None:
+        # Force the planner onto the HNSW index for this transaction only.
+        session.execute(sqlalchemy.text("SET LOCAL enable_seqscan = off;"))
+        session.execute(
+            sqlalchemy.text(
+                # int() guards against SQL injection via a non-numeric value.
+                f"SET LOCAL hnsw.ef_search = {int(kwargs.get('ef_search', 100))};"
+            )
+        )
+
+    async def _aexecute_hnsw_settings(
+        self, session: AsyncSession, **kwargs: Any
+    ) -> None:
+        # Force the planner onto the HNSW index for this transaction only.
+        await session.execute(sqlalchemy.text("SET LOCAL enable_seqscan = off;"))
+        await session.execute(
+            sqlalchemy.text(
+                # int() guards against SQL injection via a non-numeric value.
+                f"SET LOCAL hnsw.ef_search = {int(kwargs.get('ef_search', 100))};"
+            )
+        )
+
     def __query_collection(
         self,
         embedding: List[float],
         k: int = 4,
         filter: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
     ) -> Sequence[Any]:
         """Query the collection."""
         with self._make_sync_session() as session:  # type: ignore[arg-type]
@@ -1367,6 +1437,8 @@ def __query_collection(
                 filter_by.extend(filter_clauses)
 
             _type = self.EmbeddingStore
+            # Default False, matching the docstrings; True would disable
+            # sequential scans even when no HNSW index exists.
+            if kwargs.get("use_hnsw", False):
+                self._execute_hnsw_settings(session, **kwargs)
 
             results: List[Any] = (
                 session.query(
@@ -1391,6 +1463,7 @@ async def __aquery_collection(
         embedding: List[float],
         k: int = 4,
         filter: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
     ) -> Sequence[Any]:
         """Query the collection."""
         async with self._make_async_session() as session:  # type: ignore[arg-type]
@@ -1411,6 +1484,9 @@ async def __aquery_collection(
 
             _type = self.EmbeddingStore
 
+            # Default False, matching the docstrings (see __query_collection).
+            if kwargs.get("use_hnsw", False):
+                await self._aexecute_hnsw_settings(session, **kwargs)
+
             stmt = (
                 select(
                     self.EmbeddingStore,
@@ -1442,13 +1518,18 @@ def similarity_search_by_vector(
             embedding: Embedding to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
             filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+            **kwargs (Any): Additional parameters.
+                - use_hnsw (bool): Use HNSW index for similarity search.
+                    Defaults to False.
+                - ef_search (int): The number of candidates to consider
+                    during the search. Defaults to 100.
 
         Returns:
             List of Documents most similar to the query vector.
""" assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter + embedding=embedding, k=k, filter=filter, **kwargs ) return _results_to_docs(docs_and_scores) @@ -1472,7 +1553,10 @@ async def asimilarity_search_by_vector( assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter + embedding=embedding, + k=k, + filter=filter, + **kwargs, ) return _results_to_docs(docs_and_scores) @@ -1852,7 +1936,9 @@ def max_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ assert not self._async_engine, "This method must be called without async_mode" - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + results = self.__query_collection( + embedding=embedding, k=fetch_k, filter=filter, **kwargs + ) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1892,6 +1978,11 @@ async def amax_marginal_relevance_search_with_score_by_vector( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. 
Returns: List[Tuple[Document, float]]: List of Documents selected by maximal marginal @@ -1900,7 +1991,11 @@ async def amax_marginal_relevance_search_with_score_by_vector( await self.__apost_init__() # Lazy async init async with self._make_async_session() as session: results = await self.__aquery_collection( - session=session, embedding=embedding, k=fetch_k, filter=filter + session=session, + embedding=embedding, + k=fetch_k, + filter=filter, + **kwargs, ) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1940,6 +2035,11 @@ def max_marginal_relevance_search( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List[Document]: List of Documents selected by maximal marginal relevance. @@ -1978,6 +2078,11 @@ async def amax_marginal_relevance_search( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List[Document]: List of Documents selected by maximal marginal relevance. @@ -2017,6 +2122,11 @@ def max_marginal_relevance_search_with_score( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. 
Returns: List[Tuple[Document, float]]: List of Documents selected by maximal marginal @@ -2057,6 +2167,11 @@ async def amax_marginal_relevance_search_with_score( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List[Tuple[Document, float]]: List of Documents selected by maximal marginal @@ -2099,6 +2214,11 @@ def max_marginal_relevance_search_by_vector( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List[Document]: List of Documents selected by maximal marginal relevance. @@ -2139,6 +2259,11 @@ async def amax_marginal_relevance_search_by_vector( to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + **kwargs (Any): Additional parameters. + - use_hnsw (bool): Use HNSW index for similarity search. + Defaults to False. + - ef_search (int): The number of candidates to consider + during the search. Defaults to 100. Returns: List[Document]: List of Documents selected by maximal marginal relevance. 
@@ -2157,6 +2282,66 @@ async def amax_marginal_relevance_search_by_vector(
 
         return _results_to_docs(docs_and_scores)
 
+    def _prepare_create_hnsw_index_query(
+        self,
+        dims: int = ADA_TOKEN_COUNT,
+        distance_strategy: HNSWDistanceStrategy = DEFAULT_HNSW_DISTANCE_STRATEGY,
+        m: int = 8,
+        ef_construction: int = 16,
+    ) -> sqlalchemy.TextClause:
+        # Use .value explicitly: on Python 3.11+, format() of a (str, Enum)
+        # member renders "HNSWDistanceStrategy.COSINE" instead of the opclass
+        # name, which would produce invalid SQL.
+        create_index_query = sqlalchemy.text(
+            "CREATE INDEX IF NOT EXISTS langchain_pg_embedding_idx "
+            "ON langchain_pg_embedding USING hnsw ((embedding::vector({})) {}) "
+            "WITH ("
+            "m = {}, "
+            "ef_construction = {}"
+            ");".format(dims, distance_strategy.value, m, ef_construction)
+        )
+
+        return create_index_query
+
+    def create_hnsw_index(
+        self,
+        distance_strategy: HNSWDistanceStrategy = DEFAULT_HNSW_DISTANCE_STRATEGY,
+        m: int = 8,
+        ef_construction: int = 16,
+    ) -> None:
+        assert self._engine, "engine not found"
+        create_index_query = self._prepare_create_hnsw_index_query(
+            distance_strategy=distance_strategy, m=m, ef_construction=ef_construction
+        )
+
+        # Execute the queries
+        try:
+            with self._make_sync_session() as session:
+                session.execute(create_index_query)
+                session.commit()
+            print("HNSW extension and index created successfully.")  # noqa: T201
+        except Exception as e:
+            print(f"Failed to create HNSW extension or index: {e}")  # noqa: T201
+            raise e
+
+    async def acreate_hnsw_index(
+        self,
+        distance_strategy: HNSWDistanceStrategy = DEFAULT_HNSW_DISTANCE_STRATEGY,
+        m: int = 8,
+        ef_construction: int = 16,
+    ) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+        create_index_query = self._prepare_create_hnsw_index_query(
+            distance_strategy=distance_strategy, m=m, ef_construction=ef_construction
+        )
+
+        # Execute the queries
+        try:
+            async with self._make_async_session() as session:
+                await session.execute(create_index_query)
+                await session.commit()
+            print("HNSW extension and index created successfully.")  # noqa: T201
+        except Exception as e:
+            print(f"Failed to create HNSW
extension or index: {e}") # noqa: T201 + raise e + @contextlib.contextmanager def _make_sync_session(self) -> Generator[Session, None, None]: """Make an async session.""" diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index cf0a184..e0e45d0 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1020,3 +1020,75 @@ def test_validate_operators() -> None: "$not", "$or", ] + + +def test_pgvector_similarity_search_with_hnsw() -> None: + """Test similarity search using HNSW index.""" + texts = ["foo", "bar", "baz"] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_hnsw", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + docsearch.create_hnsw_index() + output = docsearch.similarity_search("foo", k=1, use_hnsw=True, ef_search=200) + assert output == [Document(page_content="foo")] + + +@pytest.mark.asyncio +async def test_async_pgvector_similarity_search_with_hnsw() -> None: + """Test similarity search using HNSW index asynchronously.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_hnsw", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + await docsearch.acreate_hnsw_index() + output = await docsearch.asimilarity_search( + "foo", k=1, use_hnsw=True, ef_search=200 + ) + assert output == [Document(page_content="foo")] + + +def test_pgvector_similarity_search_with_hnsw_and_filter() -> None: + """Test similarity search using HNSW index with a filter.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_hnsw_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + 
pre_delete_collection=True, + ) + docsearch.create_hnsw_index() + output = docsearch.similarity_search( + "foo", k=1, use_hnsw=True, ef_search=200, filter={"page": "0"} + ) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + +@pytest.mark.asyncio +async def test_async_pgvector_similarity_search_with_hnsw_and_filter() -> None: + """Test similarity search using HNSW index with a filter asynchronously.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_hnsw_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + await docsearch.acreate_hnsw_index() + output = await docsearch.asimilarity_search( + "foo", k=1, use_hnsw=True, ef_search=200, filter={"page": "0"} + ) + assert output == [Document(page_content="foo", metadata={"page": "0"})]