[vertexai] Add Hybrid Search Capabilities to VertexAI Vector Search #628

Open · wants to merge 20 commits into main · Changes from 7 commits
116 changes: 96 additions & 20 deletions libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Union

from google.cloud import storage # type: ignore[attr-defined, unused-ignore]
from google.cloud.aiplatform import telemetry
@@ -8,6 +8,7 @@
MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
HybridQuery,
MatchNeighbor,
Namespace,
NumericNamespace,
@@ -30,17 +31,33 @@ class Searcher(ABC):
def find_neighbors(
self,
embeddings: List[List[float]],
sparse_embeddings: Optional[List[Dict[str, List[int] | List[float]]]] = None,
k: int = 4,
rrf_ranking_alpha: float = 1,
filter_: Union[List[Namespace], None] = None,
numeric_filter: Union[List[NumericNamespace], None] = None,
) -> List[List[Tuple[str, float]]]:
) -> List[List[Dict[str, Any]]]:
"""Finds the k closes neighbors of each instance of embeddings.
Args:
embedding: List of embeddings vectors.
sparse_embeddings: List of Sparse embedding dictionaries which represents an
embedding as a list of dimensions and as a list of sparse values:
ie. [{"values": [0.7, 0.5], "dimensions": [10, 20]}]
k: Number of neighbors to be retrieved.
rrf_ranking_alpha: Reciprocal Ranking Fusion weight, float between 0 and 1.0
Weights Dense Search VS Sparse Search, as an example:
- rrf_ranking_alpha=1: Only Dense
- rrf_ranking_alpha=0: Only Sparse
- rrf_ranking_alpha=0.7: 0.7 weighting for dense and 0.3 for sparse
filter_: List of filters to apply.
Returns:
List of lists of Tuples (id, distance) for each embedding vector.
List of records: [
{
"doc_id": doc_id,
"dense_distance": dense_distance,
"sparse_distance": sparse_distance
}
]
"""
raise NotImplementedError()
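For context, the sparse-embedding dictionaries this signature accepts pair each non-zero value with its vocabulary dimension. A minimal sketch of building one from token weights (the helper and the `vocab` mapping are hypothetical, not part of this PR):

```python
from typing import Dict, List, Union

def to_sparse_embedding(
    token_weights: Dict[str, float], vocab: Dict[str, int]
) -> Dict[str, Union[List[int], List[float]]]:
    """Map non-zero token weights to the {"values", "dimensions"} format."""
    # Only tokens known to the vocabulary contribute a dimension.
    known = [tok for tok in token_weights if tok in vocab]
    return {
        "values": [float(token_weights[tok]) for tok in known],
        "dimensions": [vocab[tok] for tok in known],
    }

# Produces {"values": [0.7, 0.5], "dimensions": [10, 20]}
print(to_sparse_embedding({"hybrid": 0.7, "search": 0.5},
                          {"hybrid": 10, "search": 20}))
```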

@@ -49,6 +66,7 @@ def add_to_index(
self,
ids: List[str],
embeddings: List[List[float]],
sparse_embeddings: Optional[List[Dict[str, List[int] | List[float]]]] = None,
metadatas: Union[List[dict], None] = None,
is_complete_overwrite: bool = False,
**kwargs: Any,
@@ -58,6 +76,7 @@
Args:
ids: List of unique ids.
embeddings: List of embeddings for each record.
sparse_embeddings: List of sparse embeddings for each record.
metadatas: List of metadata of each record.
"""
raise NotImplementedError()
@@ -80,21 +99,36 @@ def get_datapoints_by_filter(

def _postprocess_response(
self, response: List[List[MatchNeighbor]]
) -> List[List[Tuple[str, float]]]:
"""Posproceses an endpoint response and converts it to a list of list of
tuples instead of using vertexai objects.
) -> List[List[Dict[str, Any]]]:
"""Posproceses an endpoint response and converts it to a list of list of records
instead of using vertexai objects.
Args:
response: Endpoint response.
Returns:
List of list of tuples of (id, distance).
List of records: [
{
"doc_id": doc_id,
"dense_distance": dense_distance,
"sparse_distance": sparse_distance
}
]
"""
return [
[
(neighbor.id, cast(float, neighbor.distance))
for neighbor in matching_neighbor_list
]
for matching_neighbor_list in response
]
queries_results = []
for matching_neighbor_list in response:
query_results = []
for neighbor in matching_neighbor_list:
dense_dist = neighbor.distance if neighbor.distance else 0.0
sparse_dist = (
neighbor.sparse_distance if neighbor.sparse_distance else 0.0
)
result = {
"doc_id": neighbor.id,
"dense_distance": dense_dist,
"sparse_distance": sparse_dist,
}
query_results.append(result)
queries_results.append(query_results)
return queries_results
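Given that record shape, a caller that previously unpacked `(id, distance)` tuples now reads named fields instead; a small illustrative sketch with made-up values:

```python
# One query's results in the new record format (made-up values).
records = [
    {"doc_id": "doc-1", "dense_distance": 0.12, "sparse_distance": 0.0},
    {"doc_id": "doc-2", "dense_distance": 0.0, "sparse_distance": 0.45},
]

# Equivalent of the old [(id, distance), ...] view, keeping the dense score.
pairs = [(r["doc_id"], r["dense_distance"]) for r in records]
print(pairs)  # [('doc-1', 0.12), ('doc-2', 0.0)]
```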


class VectorSearchSearcher(Searcher):
@@ -139,7 +173,7 @@ def get_datapoints_by_filter(
neighbors = self.find_neighbors(
embeddings=embeddings, k=max_datapoints, filter_=filter_
)
return [_id for (_id, _) in neighbors[0]] if neighbors else []
return [elem["doc_id"] for elem in neighbors[0]] if neighbors else []

def remove_datapoints(
self,
@@ -152,6 +186,7 @@ def add_to_index(
self,
ids: List[str],
embeddings: List[List[float]],
sparse_embeddings: Optional[List[Dict[str, List[int] | List[float]]]] = None,
metadatas: Union[List[dict], None] = None,
is_complete_overwrite: bool = False,
**kwargs: Any,
@@ -161,11 +196,17 @@
Args:
ids: List of unique ids.
embeddings: List of embeddings for each record.
sparse_embeddings: List of sparse embeddings for each record.
metadatas: List of metadata of each record.
is_complete_overwrite: Whether to overwrite everything.
"""

data_points = to_data_points(ids, embeddings, metadatas)
data_points = to_data_points(
ids=ids,
embeddings=embeddings,
sparse_embeddings=sparse_embeddings,
metadatas=metadatas,
)

if self._stream_update:
stream_update_index(index=self._index, data_points=data_points)
@@ -185,26 +226,61 @@
def find_neighbors(
self,
embeddings: List[List[float]],
sparse_embeddings: Optional[List[Dict[str, List[int] | List[float]]]] = None,
k: int = 4,
rrf_ranking_alpha: float = 1,
filter_: Union[List[Namespace], None] = None,
numeric_filter: Union[List[NumericNamespace], None] = None,
) -> List[List[Tuple[str, float]]]:
) -> List[List[Dict[str, Any]]]:
"""Finds the k closes neighbors of each instance of embeddings.
Args:
embedding: List of embeddings vectors.
embeddings: List of embedding vectors.
sparse_embeddings: List of Sparse embedding dictionaries which represents an
embedding as a list of dimensions and as a list of sparse values:
ie. [{"values": [0.7, 0.5], "dimensions": [10, 20]}]
k: Number of neighbors to be retrieved.
rrf_ranking_alpha: Reciprocal Ranking Fusion weight, float between 0 and 1.0
Weights Dense Search VS Sparse Search, as an example:
- rrf_ranking_alpha=1: Only Dense
- rrf_ranking_alpha=0: Only Sparse
- rrf_ranking_alpha=0.7: 0.7 weighting for dense and 0.3 for sparse
filter_: List of filters to apply.
Returns:
List of lists of Tuples (id, distance) for each embedding vector.
List of records: [
{
"doc_id": doc_id,
"dense_distance": dense_distance,
"sparse_distance": sparse_distance
}
]
"""

# No need to implement a separate method for private VPC; find_neighbors
# now works with both public and private endpoints.
_, user_agent = get_user_agent("vertex-ai-matching-engine")
with telemetry.tool_context_manager(user_agent):
if sparse_embeddings is None:
queries = embeddings
else:
if len(sparse_embeddings) != len(embeddings):
raise ValueError(
"The number of `sparse_embeddings` should match the number of "
f"`embeddings` {len(sparse_embeddings)} != {len(embeddings)}"
)
queries = []

for embedding, sparse_embedding in zip(embeddings, sparse_embeddings):
hybrid_query = HybridQuery(
sparse_embedding_dimensions=sparse_embedding["dimensions"], # type: ignore
sparse_embedding_values=sparse_embedding["values"], # type: ignore
dense_embedding=embedding,
rrf_ranking_alpha=rrf_ranking_alpha,
)
queries.append(hybrid_query) # type: ignore

response = self._endpoint.find_neighbors(
deployed_index_id=self._deployed_index_id,
queries=embeddings,
queries=queries,
num_neighbors=k,
filter=filter_,
numeric_filter=numeric_filter,
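For orientation, this is roughly what the new hybrid path sends to the endpoint, using the same `HybridQuery` construction as the diff above (the endpoint resource name, deployed index id, and all vectors are placeholders):

```python
from google.cloud.aiplatform import MatchingEngineIndexEndpoint
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
    HybridQuery,
)

# Placeholder resource name: substitute a real index endpoint.
endpoint = MatchingEngineIndexEndpoint(
    "projects/my-project/locations/us-central1/indexEndpoints/123"
)

# One hybrid query: a dense vector plus a sparse vector, fused with RRF.
query = HybridQuery(
    dense_embedding=[0.1, 0.2, 0.3],        # placeholder dense vector
    sparse_embedding_dimensions=[10, 20],   # non-zero dimensions
    sparse_embedding_values=[0.7, 0.5],     # values at those dimensions
    rrf_ranking_alpha=0.7,                  # 0.7 dense / 0.3 sparse
)

response = endpoint.find_neighbors(
    deployed_index_id="my_deployed_index",  # placeholder
    queries=[query],
    num_neighbors=4,
)
```

Each `MatchNeighbor` in the response then carries both `distance` and `sparse_distance`, which `_postprocess_response` folds into the record format shown earlier.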
13 changes: 10 additions & 3 deletions libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py
@@ -1,7 +1,7 @@
import json
import uuid
import warnings
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from google.cloud.aiplatform import MatchingEngineIndex
from google.cloud.aiplatform.compat.types import ( # type: ignore[attr-defined, unused-ignore]
@@ -65,7 +65,8 @@ def batch_update_index(
def to_data_points(
ids: List[str],
embeddings: List[List[float]],
metadatas: Union[List[Dict[str, Any]], None],
sparse_embeddings: Optional[List[Dict[str, List[int] | List[float]]]] = None,
metadatas: Union[List[Dict[str, Any]], None] = None,
) -> List["meidx_types.IndexDataPoint"]:
"""Converts triplets id, embedding, metadata into IndexDataPoints instances.

@@ -81,10 +82,15 @@
if metadatas is None:
metadatas = [{}] * len(ids)

if sparse_embeddings is None:
sparse_embeddings = [{"values": [], "dimensions": []}] * len(ids)

data_points = []
ignored_fields = set()

for id_, embedding, metadata in zip(ids, embeddings, metadatas):
for id_, embedding, sparse_embedding, metadata in zip(
ids, embeddings, sparse_embeddings, metadatas
):
restricts = []
numeric_restricts = []

@@ -122,6 +128,7 @@
data_point = meidx_types.IndexDatapoint(
datapoint_id=id_,
feature_vector=embedding,
sparse_embedding=sparse_embedding,
restricts=restricts,
numeric_restricts=numeric_restricts,
)
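As a usage sketch, the extended `to_data_points` now accepts parallel dense and sparse embeddings (the import path follows this file's location; ids, vectors, and metadata are placeholders):

```python
from langchain_google_vertexai.vectorstores._utils import to_data_points

data_points = to_data_points(
    ids=["doc-1"],
    embeddings=[[0.1, 0.2, 0.3]],  # one dense vector per id
    sparse_embeddings=[{"values": [0.7], "dimensions": [10]}],  # one sparse vector per id
    metadatas=[{"color": "red"}],  # becomes restricts on the datapoint
)
```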