diff --git a/integration/test_named_vectors.py b/integration/test_named_vectors.py index a544ba54f..21be17251 100644 --- a/integration/test_named_vectors.py +++ b/integration/test_named_vectors.py @@ -11,10 +11,12 @@ _VectorIndexConfigHNSW, _VectorIndexConfigFlat, Vectorizers, + ReferenceProperty, ) from weaviate.collections.classes.data import DataObject from weaviate.collections.classes.grpc import _MultiTargetVectorJoin from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.types import INCLUDE_VECTOR def test_create_named_vectors_throws_error_in_old_version( @@ -756,3 +758,46 @@ def test_deprecated_syntax(collection_factory: CollectionFactory): return_metadata=wvc.query.MetadataQuery.full(), ) assert "Providing lists of lists has been deprecated" in str(e) + + +@pytest.mark.parametrize( + "include_vector, expected", + [ + (False, {}), + (["bringYourOwn1"], {"bringYourOwn1": [0, 1, 2]}), + # TODO: to be uncommented when https://github.com/weaviate/weaviate/issues/6279 is resolved + # (True, {"bringYourOwn1": [0, 1, 2], "bringYourOwn2": [3, 4, 5]}) + ], +) +def test_include_vector_on_references( + collection_factory: CollectionFactory, include_vector: INCLUDE_VECTOR, expected: dict +) -> None: + """Test include vector on reference""" + dummy = collection_factory() + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + pytest.skip("Named vectorizers are only supported in Weaviate v1.24.0 and higher.") + + ref_collection = collection_factory( + name="Target", + vectorizer_config=[ + wvc.config.Configure.NamedVectors.none(name="bringYourOwn1"), + wvc.config.Configure.NamedVectors.none(name="bringYourOwn2"), + ], + ) + + TO_UUID = ref_collection.data.insert( + properties={}, vector={"bringYourOwn1": [0, 1, 2], "bringYourOwn2": [3, 4, 5]} + ) + + collection = collection_factory( + name="Source", + references=[ReferenceProperty(name="hasRef", target_collection=ref_collection.name)], + ) + + collection.data.insert({}, references={"hasRef": TO_UUID}) + + objs = collection.query.fetch_objects( + return_references=wvc.query.QueryReference(link_on="hasRef", include_vector=include_vector) + ).objects + + assert objs[0].references["hasRef"].objects[0].vector == expected diff --git a/weaviate/collections/classes/grpc.py b/weaviate/collections/classes/grpc.py index a708e6d58..c053aca42 100644 --- a/weaviate/collections/classes/grpc.py +++ b/weaviate/collections/classes/grpc.py @@ -433,7 +433,7 @@ def near_vector( class _QueryReference(_WeaviateInput): link_on: str - include_vector: bool = Field(default=False) + include_vector: INCLUDE_VECTOR = Field(default=False) return_metadata: Optional[MetadataQuery] = Field(default=None) return_properties: Union["PROPERTIES", bool, None] = Field(default=None) return_references: Optional["REFERENCES"] = Field(default=None)