Skip to content

Commit

Permalink
Added warning for unused fields in filtering. Added unit tests for to…
Browse files Browse the repository at this point in the history
…_datapoint
  • Loading branch information
Jorge committed Mar 13, 2024
1 parent 38400a3 commit 8f78490
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
11 changes: 11 additions & 0 deletions libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import uuid
import warnings
from typing import Any, Dict, List, Union

from google.cloud.aiplatform import MatchingEngineIndex
Expand Down Expand Up @@ -81,6 +82,7 @@ def to_data_points(
metadatas = [{}] * len(ids)

data_points = []
ignored_fields = set()

for id_, embedding, metadata in zip(ids, embeddings, metadatas):
restricts = []
Expand All @@ -107,6 +109,15 @@ def to_data_points(
namespace=namespace, value_float=value
)
numeric_restricts.append(restriction)
else:
ignored_fields.add(namespace)

if len(ignored_fields) > 0:
warnings.warn(
f"Some values in fields {', '.join(ignored_fields)} are not usable for"
f" restrictions. In order to be used they must be str, list[str] or"
f" numeric."
)

data_point = meidx_types.IndexDatapoint(
datapoint_id=id_,
Expand Down
49 changes: 49 additions & 0 deletions libs/vertexai/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

from langchain_google_vertexai.vectorstores._utils import to_data_points


def test_to_data_points():
    """Check that ``to_data_points`` turns metadata into datapoint restrictions.

    A single datapoint is built from metadata holding a string, a number, a
    list of strings and a dict. The string and list fields must show up as
    ``restricts``, the number as a ``numeric_restricts`` entry, and the dict
    field must be skipped with a warning that names it.
    """
    ids = ["Id1"]
    embeddings = [[0.0, 0.0]]
    metadatas = [
        {
            "some_string": "string",
            "some_number": 1.1,
            "some_list": ["a", "b"],
            "some_random_object": {"foo": 1, "bar": 2},
        }
    ]

    # The dict-valued field cannot be used as a restriction, so the call is
    # expected to warn and mention that field by name.
    with pytest.warns(UserWarning, match="some_random_object"):
        result = to_data_points(ids, embeddings, metadatas)

    assert isinstance(result, list)
    assert len(result) == 1

    datapoint = result[0]
    assert datapoint.datapoint_id == "Id1"

    # Compare the stored vector with the input element by element; the length
    # check guards against zip() silently truncating a mismatched vector.
    assert len(datapoint.feature_vector) == len(embeddings[0])
    for component_fv, component_emb in zip(datapoint.feature_vector, embeddings[0]):
        assert component_fv == pytest.approx(component_emb)

    metadata = metadatas[0]

    restriction_lookup = {
        restriction.namespace: restriction for restriction in datapoint.restricts
    }

    restriction = restriction_lookup.pop("some_string")
    assert restriction.allow_list == [metadata["some_string"]]

    restriction = restriction_lookup.pop("some_list")
    assert restriction.allow_list == metadata["some_list"]

    # Nothing else (in particular not the dict field) may become a restriction.
    assert len(restriction_lookup) == 0

    num_restriction_lookup = {
        restriction.namespace: restriction
        for restriction in datapoint.numeric_restricts
    }
    restriction = num_restriction_lookup.pop("some_number")
    assert round(restriction.value_float, 1) == pytest.approx(metadata["some_number"])
    assert len(num_restriction_lookup) == 0

0 comments on commit 8f78490

Please sign in to comment.