diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py index 5538aaf1..fc6f7b77 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py @@ -1,5 +1,6 @@ import json import uuid +import warnings from typing import Any, Dict, List, Union from google.cloud.aiplatform import MatchingEngineIndex @@ -81,6 +82,7 @@ def to_data_points( metadatas = [{}] * len(ids) data_points = [] + ignored_fields = set() for id_, embedding, metadata in zip(ids, embeddings, metadatas): restricts = [] @@ -107,6 +109,15 @@ def to_data_points( namespace=namespace, value_float=value ) numeric_restricts.append(restriction) + else: + ignored_fields.add(namespace) + + if len(ignored_fields) > 0: + warnings.warn( + f"Some values in fields {', '.join(ignored_fields)} are not usable for" + f" restrictions. In order to be used they must be str, list[str] or" + f" numeric." 
+ ) data_point = meidx_types.IndexDatapoint( datapoint_id=id_,
diff --git a/libs/vertexai/tests/unit_tests/test_vectorstores.py b/libs/vertexai/tests/unit_tests/test_vectorstores.py
new file mode 100644
index 00000000..9b1c30bc
--- /dev/null
+++ b/libs/vertexai/tests/unit_tests/test_vectorstores.py
@@ -0,0 +1,61 @@
+import pytest
+
+from langchain_google_vertexai.vectorstores._utils import to_data_points
+
+
+def test_to_data_points() -> None:
+    """Check that ``to_data_points`` maps metadata onto datapoint restrictions.
+
+    String and list-of-string values must become ``restricts``, numeric values
+    must become ``numeric_restricts``, and any other value must be skipped
+    with a warning naming the offending field.
+    """
+    ids = ["Id1"]
+    embeddings = [[0.0, 0.0]]
+    metadatas = [
+        {
+            "some_string": "string",
+            "some_number": 1.1,
+            "some_list": ["a", "b"],
+            "some_random_object": {"foo": 1, "bar": 2},
+        }
+    ]
+
+    # The dict value cannot be expressed as a restriction, so a UserWarning
+    # (the default category of warnings.warn) mentioning the field is expected.
+    with pytest.warns(UserWarning, match="some_random_object"):
+        result = to_data_points(ids, embeddings, metadatas)
+
+    assert isinstance(result, list)
+    assert len(result) == 1
+
+    datapoint = result[0]
+    assert datapoint.datapoint_id == "Id1"
+
+    # zip() stops at the shorter input, so compare the lengths explicitly first.
+    assert len(datapoint.feature_vector) == len(embeddings[0])
+    for component_emb, component_fv in zip(datapoint.feature_vector, embeddings[0]):
+        assert component_emb == pytest.approx(component_fv)
+
+    metadata = metadatas[0]
+
+    restriction_lookup = {
+        restriction.namespace: restriction for restriction in datapoint.restricts
+    }
+
+    restriction = restriction_lookup.pop("some_string")
+    assert restriction.allow_list == [metadata["some_string"]]
+
+    restriction = restriction_lookup.pop("some_list")
+    assert restriction.allow_list == metadata["some_list"]
+
+    # Nothing else (in particular not the ignored field) may be a restriction.
+    assert len(restriction_lookup) == 0
+
+    num_restriction_lookup = {
+        restriction.namespace: restriction
+        for restriction in datapoint.numeric_restricts
+    }
+    restriction = num_restriction_lookup.pop("some_number")
+    assert round(restriction.value_float, 1) == pytest.approx(metadata["some_number"])
+    assert len(num_restriction_lookup) == 0