Skip to content

Commit

Permalink
partners/milvus: allow creating a vectorstore with sparse embeddings (#…
Browse files Browse the repository at this point in the history
…25284)

# Description
Milvus (and `pymilvus`) recently added the option to use [sparse
vectors](https://milvus.io/docs/sparse_vector.md#Sparse-Vector) with
appropriate search methods (e.g., `SPARSE_INVERTED_INDEX`) and
embeddings (e.g., `BM25`, `SPLADE`).

This PR allows creating a vector store using langchain's `Milvus` class,
setting the matching vector field type to `DataType.SPARSE_FLOAT_VECTOR`
and the default index type to `SPARSE_INVERTED_INDEX`.

It only extends existing functionality and is backward compatible.

## Note
I am also interested in extending the Milvus class further to support multi
vector search (aka hybrid search), and will be happy to discuss that. See
[here](#19955),
[here](#20375), and
[here](#22886)
for similar needs.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
ohadeytan and efriis authored Aug 30, 2024
1 parent 09b04c7 commit b5d6704
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 23 deletions.
68 changes: 47 additions & 21 deletions libs/partners/milvus/langchain_milvus/vectorstores/milvus.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import logging
from typing import Any, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_milvus.utils.sparse import BaseSparseEmbedding

logger = logging.getLogger(__name__)

DEFAULT_MILVUS_CONNECTION = {
Expand Down Expand Up @@ -110,7 +112,7 @@ class Milvus(VectorStore):
Name of the collection.
collection_description: str
Description of the collection.
embedding_function: Embeddings
embedding_function: Union[Embeddings, BaseSparseEmbedding]
Embedding function to use.
Key init args — client params:
Expand Down Expand Up @@ -219,7 +221,7 @@ class Milvus(VectorStore):

def __init__(
self,
embedding_function: Embeddings,
embedding_function: Union[Embeddings, BaseSparseEmbedding], # type: ignore
collection_name: str = "LangChainCollection",
collection_description: str = "",
collection_properties: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -276,6 +278,11 @@ def __init__(
},
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"SPARSE_INVERTED_INDEX": {
"metric_type": "IP",
"params": {"drop_ratio_build": 0.2},
},
"SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
}

self.embedding_func = embedding_function
Expand Down Expand Up @@ -340,7 +347,7 @@ def __init__(
)

@property
def embeddings(self) -> Embeddings:
def embeddings(self) -> Union[Embeddings, BaseSparseEmbedding]: # type: ignore
return self.embedding_func

def _create_connection_alias(self, connection_args: dict) -> str:
Expand Down Expand Up @@ -402,6 +409,10 @@ def _create_connection_alias(self, connection_args: dict) -> str:
logger.error("Failed to create new connection using: %s", alias)
raise e

@property
def _is_sparse_embedding(self) -> bool:
    """Tell whether the configured embedding function is a sparse one.

    Returns:
        True when ``self.embedding_func`` is an instance of
        ``BaseSparseEmbedding``, False otherwise.
    """
    embedding_fn = self.embedding_func
    return isinstance(embedding_fn, BaseSparseEmbedding)

def _init(
self,
embeddings: Optional[list] = None,
Expand Down Expand Up @@ -539,9 +550,14 @@ def _create_collection(
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
)
if self._is_sparse_embedding:
fields.append(FieldSchema(self._vector_field, DataType.SPARSE_FLOAT_VECTOR))
else:
fields.append(
FieldSchema(
self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim
)
)

# Create the schema for the collection
schema = CollectionSchema(
Expand Down Expand Up @@ -606,11 +622,18 @@ def _create_index(self) -> None:
try:
# If no index params, use a default HNSW based one
if self.index_params is None:
self.index_params = {
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
if self._is_sparse_embedding:
self.index_params = {
"metric_type": "IP",
"index_type": "SPARSE_INVERTED_INDEX",
"params": {"drop_ratio_build": 0.2},
}
else:
self.index_params = {
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}

try:
self.col.create_index(
Expand Down Expand Up @@ -740,7 +763,7 @@ def add_texts(
)

try:
embeddings = self.embedding_func.embed_documents(texts)
embeddings: list = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]

Expand Down Expand Up @@ -815,7 +838,7 @@ def add_texts(

def _collection_search(
self,
embedding: List[float],
embedding: List[float] | Dict[int, float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
Expand All @@ -829,7 +852,8 @@ def _collection_search(
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
Args:
embedding (List[float]): The embedding vector being searched.
embedding (List[float] | Dict[int, float]): The embedding vector being
searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
Expand Down Expand Up @@ -976,7 +1000,7 @@ def similarity_search_with_score(

def similarity_search_with_score_by_vector(
self,
embedding: List[float],
embedding: List[float] | Dict[int, float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
Expand All @@ -990,7 +1014,8 @@ def similarity_search_with_score_by_vector(
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
Args:
embedding (List[float]): The embedding vector being searched.
embedding (List[float] | Dict[int, float]): The embedding vector being
searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
Expand Down Expand Up @@ -1068,7 +1093,7 @@ def max_marginal_relevance_search(

def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
embedding: list[float] | dict[int, float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
Expand All @@ -1080,7 +1105,8 @@ def max_marginal_relevance_search_by_vector(
"""Perform a search and return results that are reordered by MMR.
Args:
embedding (str): The embedding vector being searched.
embedding (list[float] | dict[int, float]): The embedding vector being
searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
Expand Down Expand Up @@ -1171,7 +1197,7 @@ def delete( # type: ignore[no-untyped-def]
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
embedding: Union[Embeddings, BaseSparseEmbedding], # type: ignore
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
Expand All @@ -1187,7 +1213,7 @@ def from_texts(
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
embedding (Union[Embeddings, BaseSparseEmbedding]): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
Expand Down
5 changes: 3 additions & 2 deletions libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from langchain_core.embeddings import Embeddings

from langchain_milvus.utils.sparse import BaseSparseEmbedding
from langchain_milvus.vectorstores.milvus import Milvus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -141,7 +142,7 @@ def _create_index(self) -> None:
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
embedding: Union[Embeddings, BaseSparseEmbedding],
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: Optional[Dict[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from langchain_core.documents import Document

from langchain_milvus.utils.sparse import BM25SparseEmbedding
from langchain_milvus.vectorstores import Milvus
from tests.integration_tests.utils import (
FakeEmbeddings,
Expand Down Expand Up @@ -304,6 +305,31 @@ def test_milvus_enable_dynamic_field_with_partition_key() -> None:
}


def test_milvus_sparse_embeddings() -> None:
    """Create a Milvus store from BM25 sparse embeddings and query it.

    A ``BM25SparseEmbedding`` is constructed from a three-text corpus, the
    same texts are loaded through ``Milvus.from_texts``, and a top-1
    similarity search for "Pilgrim" must return a document whose content
    contains that word.
    """
    corpus = [
        "In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers "
        "a hidden world of clockwork machines and ancient magic, where a rebellion is "
        "brewing against the tyrannical ruler of the land.",
        "In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by "
        "a mysterious organization to transport a valuable artifact across a war-torn "
        "continent, but soon finds themselves pursued by assassins and rival factions.",
        "In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers "
        "she has the ability to enter people's dreams, but soon finds herself trapped "
        "in a surreal world of nightmares and illusions, where the boundaries between "
        "reality and fantasy blur.",
    ]
    bm25_embedding = BM25SparseEmbedding(corpus=corpus)

    # drop_old=True is passed so the collection is rebuilt for this run.
    store = Milvus.from_texts(
        texts=corpus,
        embedding=bm25_embedding,
        drop_old=True,
        connection_args={"uri": "./milvus_demo.db"},
    )

    hits = store.similarity_search("Pilgrim", k=1)
    top_hit = hits[0]
    assert "Pilgrim" in top_hit.page_content


def test_milvus_array_field() -> None:
"""Manually specify metadata schema, including an array_field.
For more information about array data type and filtering, please refer to
Expand Down Expand Up @@ -365,4 +391,6 @@ def test_milvus_array_field() -> None:
# test_milvus_enable_dynamic_field()
# test_milvus_disable_dynamic_field()
# test_milvus_metadata_field()
# test_milvus_enable_dynamic_field_with_partition_key()
# test_milvus_sparse_embeddings()
# test_milvus_array_field()

0 comments on commit b5d6704

Please sign in to comment.