diff --git a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py index 7a9476981a86e..5419119ba0867 100644 --- a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import uuid4 import numpy as np @@ -9,6 +9,8 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +from langchain_milvus.utils.sparse import BaseSparseEmbedding + logger = logging.getLogger(__name__) DEFAULT_MILVUS_CONNECTION = { @@ -110,7 +112,7 @@ class Milvus(VectorStore): Name of the collection. collection_description: str Description of the collection. - embedding_function: Embeddings + embedding_function: Union[Embeddings, BaseSparseEmbedding] Embedding function to use. Key init args — client params: @@ -219,7 +221,7 @@ class Milvus(VectorStore): def __init__( self, - embedding_function: Embeddings, + embedding_function: Union[Embeddings, BaseSparseEmbedding], # type: ignore collection_name: str = "LangChainCollection", collection_description: str = "", collection_properties: Optional[dict[str, Any]] = None, @@ -276,6 +278,11 @@ def __init__( }, "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "SPARSE_INVERTED_INDEX": { + "metric_type": "IP", + "params": {"drop_ratio_build": 0.2}, + }, + "SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, } self.embedding_func = embedding_function @@ -340,7 +347,7 @@ def __init__( ) @property - def embeddings(self) -> Embeddings: + def embeddings(self) -> Union[Embeddings, BaseSparseEmbedding]: # type: ignore return self.embedding_func def _create_connection_alias(self, connection_args: dict) -> str: @@ -402,6 +409,10 @@ def _create_connection_alias(self, connection_args: dict) -> str: logger.error("Failed to create new connection using: %s", alias) raise e + @property + def _is_sparse_embedding(self) -> bool: + return isinstance(self.embedding_func, BaseSparseEmbedding) + def _init( self, embeddings: Optional[list] = None, @@ -539,9 +550,14 @@ def _create_collection( ) ) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + if self._is_sparse_embedding: + fields.append(FieldSchema(self._vector_field, DataType.SPARSE_FLOAT_VECTOR)) + else: + fields.append( + FieldSchema( + self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim + ) + ) # Create the schema for the collection schema = CollectionSchema( @@ -606,11 +622,18 @@ def _create_index(self) -> None: try: # If no index params, use a default HNSW based one if self.index_params is None: - self.index_params = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, - } + if self._is_sparse_embedding: + self.index_params = { + "metric_type": "IP", + "index_type": "SPARSE_INVERTED_INDEX", + "params": {"drop_ratio_build": 0.2}, + } + else: + self.index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } try: self.col.create_index( @@ -740,7 +763,7 @@ def add_texts( ) try: - embeddings = self.embedding_func.embed_documents(texts) + embeddings: list = self.embedding_func.embed_documents(texts) except NotImplementedError: embeddings = [self.embedding_func.embed_query(x) for x in texts] @@ -815,7 +838,7 @@ def add_texts( def _collection_search( self, - embedding: List[float], + embedding: List[float] | Dict[int, float], k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, @@ -829,7 +852,8 @@ def _collection_search( https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md Args: - embedding (List[float]): The embedding vector being searched. + embedding (List[float] | Dict[int, float]): The embedding vector being + searched. k (int, optional): The amount of results to return. Defaults to 4. param (dict): The search params for the specified index. Defaults to None. @@ -976,7 +1000,7 @@ def similarity_search_with_score( def similarity_search_with_score_by_vector( self, - embedding: List[float], + embedding: List[float] | Dict[int, float], k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, @@ -990,7 +1014,8 @@ def similarity_search_with_score_by_vector( https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md Args: - embedding (List[float]): The embedding vector being searched. + embedding (List[float] | Dict[int, float]): The embedding vector being + searched. k (int, optional): The amount of results to return. Defaults to 4. param (dict): The search params for the specified index. Defaults to None. @@ -1068,7 +1093,7 @@ def max_marginal_relevance_search( def max_marginal_relevance_search_by_vector( self, - embedding: list[float], + embedding: list[float] | dict[int, float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, @@ -1080,7 +1105,8 @@ def max_marginal_relevance_search_by_vector( """Perform a search and return results that are reordered by MMR. Args: - embedding (str): The embedding vector being searched. + embedding (list[float] | dict[int, float]): The embedding vector being + searched. k (int, optional): How many results to give. Defaults to 4. fetch_k (int, optional): Total results to select k from. Defaults to 20. @@ -1171,7 +1197,7 @@ def delete( # type: ignore[no-untyped-def] def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: Union[Embeddings, BaseSparseEmbedding], # type: ignore metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, @@ -1187,7 +1213,7 @@ def from_texts( Args: texts (List[str]): Text data. - embedding (Embeddings): Embedding function. + embedding (Union[Embeddings, BaseSparseEmbedding]): Embedding function. metadatas (Optional[List[dict]]): Metadata for each text if it exists. Defaults to None. collection_name (str, optional): Collection name to use. Defaults to diff --git a/libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py b/libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py index 02f2ce739ff4c..976651dbde8b4 100644 --- a/libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py +++ b/libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py @@ -1,10 +1,11 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from langchain_core.embeddings import Embeddings +from langchain_milvus.utils.sparse import BaseSparseEmbedding from langchain_milvus.vectorstores.milvus import Milvus logger = logging.getLogger(__name__) @@ -141,7 +142,7 @@ def _create_index(self) -> None: def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: Union[Embeddings, BaseSparseEmbedding], metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", connection_args: Optional[Dict[str, Any]] = None, diff --git a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py index 2066f61c56d51..cd7282c06e1e2 100644 --- a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py +++ b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py @@ -5,6 +5,7 @@ import pytest from langchain_core.documents import Document +from langchain_milvus.utils.sparse import BM25SparseEmbedding from langchain_milvus.vectorstores import Milvus from tests.integration_tests.utils import ( FakeEmbeddings, @@ -304,6 +305,31 @@ def test_milvus_enable_dynamic_field_with_partition_key() -> None: } +def test_milvus_sparse_embeddings() -> None: + texts = [ + "In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers " + "a hidden world of clockwork machines and ancient magic, where a rebellion is " + "brewing against the tyrannical ruler of the land.", + "In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by " + "a mysterious organization to transport a valuable artifact across a war-torn " + "continent, but soon finds themselves pursued by assassins and rival factions.", + "In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers " + "she has the ability to enter people's dreams, but soon finds herself trapped " + "in a surreal world of nightmares and illusions, where the boundaries between " + "reality and fantasy blur.", + ] + sparse_embedding_func = BM25SparseEmbedding(corpus=texts) + docsearch = Milvus.from_texts( + embedding=sparse_embedding_func, + texts=texts, + connection_args={"uri": "./milvus_demo.db"}, + drop_old=True, + ) + + output = docsearch.similarity_search("Pilgrim", k=1) + assert "Pilgrim" in output[0].page_content + + def test_milvus_array_field() -> None: """Manually specify metadata schema, including an array_field. For more information about array data type and filtering, please refer to @@ -365,4 +391,6 @@ def test_milvus_array_field() -> None: # test_milvus_enable_dynamic_field() # test_milvus_disable_dynamic_field() # test_milvus_metadata_field() +# test_milvus_enable_dynamic_field_with_partition_key() +# test_milvus_sparse_embeddings() # test_milvus_array_field()