Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor backend of aggregate API to use gRPC #1522

Merged
merged 14 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ env:
WEAVIATE_126: 1.26.13
WEAVIATE_127: 1.27.9
WEAVIATE_128: 1.28.3
WEAVIATE_129: 1.29.0-dev-7b81c72
WEAVIATE_129: 1.29.0-dev-35036a8

jobs:
lint-and-format:
Expand Down
31 changes: 22 additions & 9 deletions integration/test_collection_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,24 @@ def test_over_all_with_filters_ref(collection_factory: CollectionFactory) -> Non
assert res.properties["text"].count == 1
assert res.properties["text"].top_occurrences[0].value == "two"

with pytest.raises(WeaviateInvalidInputError):
res = collection.aggregate.over_all(
filters=Filter.by_ref("ref")
.by_property("text")
.equal("one"), # gRPC-compat API not support by GQL aggregation
return_metrics=[Metrics("text").text(count=True, top_occurrences_value=True)],
)
query = lambda: collection.aggregate.over_all(
filters=Filter.by_ref("ref").by_property("text").equal("one"),
return_metrics=[Metrics("text").text(count=True, top_occurrences_value=True)],
)
if collection._connection._weaviate_version.is_lower_than(1, 29, 0):
with pytest.raises(WeaviateInvalidInputError):
query()
else:
res = query()
assert isinstance(res.properties["text"], AggregateText)
assert res.properties["text"].count == 1
assert res.properties["text"].top_occurrences[0].value == "two"


def test_wrong_aggregation(collection_factory: CollectionFactory) -> None:
collection = collection_factory(properties=[Property(name="text", data_type=DataType.TEXT)])
if collection._connection._weaviate_version.is_at_least(1, 29, 0):
pytest.skip("GQL is only used for versions 1.28.4 and lower")
with pytest.raises(WeaviateQueryError) as e:
collection.aggregate.over_all(total_count=False)
assert (
Expand Down Expand Up @@ -658,13 +665,19 @@ def test_group_by_aggregation_argument(collection_factory: CollectionFactory) ->
groups = res.groups
assert len(groups) == 2
assert groups[0].grouped_by.prop == "int"
assert groups[0].grouped_by.value == "1" or groups[1].grouped_by.value == "1"
if collection._connection._weaviate_version.is_lower_than(1, 29, 0):
assert groups[0].grouped_by.value == "1" or groups[1].grouped_by.value == "1"
else:
assert groups[0].grouped_by.value == 1 or groups[1].grouped_by.value == 1
assert isinstance(groups[0].properties["text"], AggregateText)
assert groups[0].properties["text"].count == 1
assert isinstance(groups[0].properties["int"], AggregateInteger)
assert groups[0].properties["int"].count == 1
assert groups[1].grouped_by.prop == "int"
assert groups[1].grouped_by.value == "2" or groups[0].grouped_by.value == "2"
if collection._connection._weaviate_version.is_lower_than(1, 29, 0):
assert groups[1].grouped_by.value == "2" or groups[0].grouped_by.value == "2"
else:
assert groups[1].grouped_by.value == 2 or groups[0].grouped_by.value == 2
assert isinstance(groups[1].properties["text"], AggregateText)
assert groups[1].properties["text"].count == 1
assert isinstance(groups[1].properties["int"], AggregateInteger)
Expand Down
6 changes: 3 additions & 3 deletions test/collection/test_aggregates.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import pytest
from typing import Awaitable
from typing import Awaitable, Callable
from weaviate.connect import ConnectionV4
from weaviate.collections.aggregate import _AggregateCollectionAsync
from weaviate.exceptions import WeaviateInvalidInputError


async def _test_aggregate(aggregate: Awaitable) -> None:
async def _test_aggregate(aggregate: Callable[[], Awaitable]) -> None:
with pytest.raises(WeaviateInvalidInputError):
await aggregate()


@pytest.mark.asyncio
async def test_bad_aggregate_inputs(connection: ConnectionV4) -> None:
aggregate = _AggregateCollectionAsync(connection, "dummy", None, None)
aggregate = _AggregateCollectionAsync(connection, "dummy", None, None, False)
# over_all
await _test_aggregate(lambda: aggregate.over_all(filters="wrong"))
await _test_aggregate(lambda: aggregate.over_all(group_by=42))
Expand Down
154 changes: 147 additions & 7 deletions weaviate/collections/aggregations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AggregateDate,
AggregateInteger,
AggregateNumber,
# AggregateReference, # Aggregate references currently bugged on Weaviate's side
AggregateReference,
AggregateText,
AggregateGroup,
AggregateGroupByReturn,
Expand All @@ -25,21 +25,25 @@
_MetricsDate,
_MetricsNumber,
_MetricsInteger,
# _MetricsReference, # Aggregate references currently bugged on Weaviate's side
_MetricsReference,
dirkkul marked this conversation as resolved.
Show resolved Hide resolved
_MetricsText,
GroupedBy,
TopOccurrence,
)
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.filters import _Filters
from weaviate.collections.classes.grpc import Move
from weaviate.collections.classes.types import GeoCoordinate
from weaviate.collections.filters import _FilterToREST
from weaviate.collections.grpc.aggregate import _AggregateGRPC
from weaviate.connect import ConnectionV4
from weaviate.exceptions import WeaviateInvalidInputError, WeaviateQueryError
from weaviate.gql.aggregate import AggregateBuilder
from weaviate.proto.v1 import aggregate_pb2
from weaviate.types import NUMBER, UUID
from weaviate.util import file_encoder_b64, _decode_json_response_dict
from weaviate.validator import _ValidateArgument, _validate_input
from weaviate.warnings import _Warnings

P = ParamSpec("P")
T = TypeVar("T")
Expand All @@ -52,11 +56,19 @@ def __init__(
name: str,
consistency_level: Optional[ConsistencyLevel],
tenant: Optional[str],
validate_arguments: bool,
):
self._connection = connection
self.__name = name
self._tenant = tenant
self._consistency_level = consistency_level
self._grpc = _AggregateGRPC(
connection=connection,
name=name,
tenant=tenant,
consistency_level=consistency_level,
validate_arguments=validate_arguments,
)

def _query(self) -> AggregateBuilder:
return AggregateBuilder(
Expand All @@ -77,6 +89,77 @@ def _to_aggregate_result(
f"There was an error accessing the {e} key when parsing the GraphQL response: {response}"
)

def _to_result(
self, response: aggregate_pb2.AggregateReply
) -> Union[AggregateReturn, AggregateGroupByReturn]:
if response.HasField("single_result"):
return AggregateReturn(
properties={
aggregation.property: self.__parse_property_grpc(aggregation)
for aggregation in response.single_result.aggregations.aggregations
},
total_count=response.single_result.objects_count,
)
if response.HasField("grouped_results"):
return AggregateGroupByReturn(
groups=[
AggregateGroup(
grouped_by=self.__parse_grouped_by_value(group.grouped_by),
properties={
aggregation.property: self.__parse_property_grpc(aggregation)
for aggregation in group.aggregations.aggregations
},
total_count=group.objects_count,
)
for group in response.grouped_results.groups
]
)
else:
_Warnings.unknown_type_encountered(response.WhichOneof("result"))
return AggregateReturn(properties={}, total_count=None)

def __parse_grouped_by_value(
self, grouped_by: aggregate_pb2.AggregateReply.Group.GroupedBy
) -> GroupedBy:
value: Union[
str,
int,
float,
bool,
List[str],
List[int],
List[float],
List[bool],
GeoCoordinate,
None,
]
if grouped_by.HasField("text"):
value = grouped_by.text
elif grouped_by.HasField("int"):
value = grouped_by.int
elif grouped_by.HasField("number"):
value = grouped_by.number
elif grouped_by.HasField("boolean"):
value = grouped_by.boolean
elif grouped_by.HasField("texts"):
value = list(grouped_by.texts.values)
elif grouped_by.HasField("ints"):
value = list(grouped_by.ints.values)
elif grouped_by.HasField("numbers"):
value = list(grouped_by.numbers.values)
elif grouped_by.HasField("booleans"):
value = list(grouped_by.booleans.values)
elif grouped_by.HasField("geo"):
v = grouped_by.geo
value = GeoCoordinate(
latitude=v.latitude,
longitude=v.longitude,
)
else:
value = None
_Warnings.unknown_type_encountered(grouped_by.WhichOneof("value"))
return GroupedBy(prop=grouped_by.path[0], value=value)

def _to_group_by_result(
self, response: dict, metrics: Optional[List[_Metrics]]
) -> AggregateGroupByReturn:
Expand Down Expand Up @@ -108,13 +191,13 @@ def __parse_properties(self, result: dict, metrics: List[_Metrics]) -> AProperti
props: AProperties = {}
for metric in metrics:
if metric.property_name in result:
props[metric.property_name] = self.__parse_property(
props[metric.property_name] = self.__parse_property_gql(
result[metric.property_name], metric
)
return props

@staticmethod
def __parse_property(property_: dict, metric: _Metrics) -> AggregateResult:
def __parse_property_gql(property_: dict, metric: _Metrics) -> AggregateResult:
if isinstance(metric, _MetricsText):
return AggregateText(
count=property_.get("count"),
Expand Down Expand Up @@ -162,14 +245,71 @@ def __parse_property(property_: dict, metric: _Metrics) -> AggregateResult:
minimum=property_.get("minimum"),
mode=property_.get("mode"),
)
# Aggregate references currently bugged on Weaviate's side
# elif isinstance(metric, _MetricsReference):
# return AggregateReference(pointing_to=property_.get("pointingTo"))
elif isinstance(metric, _MetricsReference):
return AggregateReference(pointing_to=property_.get("pointingTo"))
else:
raise ValueError(
f"Unknown aggregation type {metric} encountered in _Aggregate.__parse_property() for property {property_}"
)

@staticmethod
def __parse_property_grpc(
aggregation: aggregate_pb2.AggregateReply.Aggregations.Aggregation,
) -> AggregateResult:
if aggregation.HasField("text"):
return AggregateText(
count=aggregation.text.count,
top_occurrences=[
TopOccurrence(
count=top_occurrence.occurs,
value=top_occurrence.value,
)
for top_occurrence in aggregation.text.top_occurences.items
],
)
elif aggregation.HasField("int"):
return AggregateInteger(
count=aggregation.int.count,
maximum=aggregation.int.maximum,
mean=aggregation.int.mean,
median=aggregation.int.median,
minimum=aggregation.int.minimum,
mode=aggregation.int.mode,
sum_=aggregation.int.sum,
)
elif aggregation.HasField("number"):
return AggregateNumber(
count=aggregation.number.count,
maximum=aggregation.number.maximum,
mean=aggregation.number.mean,
median=aggregation.number.median,
minimum=aggregation.number.minimum,
mode=aggregation.number.mode,
sum_=aggregation.number.sum,
)
elif aggregation.HasField("boolean"):
return AggregateBoolean(
count=aggregation.boolean.count,
percentage_false=aggregation.boolean.percentage_false,
percentage_true=aggregation.boolean.percentage_true,
total_false=aggregation.boolean.total_false,
total_true=aggregation.boolean.total_true,
)
elif aggregation.HasField("date"):
return AggregateDate(
count=aggregation.date.count,
maximum=aggregation.date.maximum,
median=aggregation.date.median,
minimum=aggregation.date.minimum,
mode=aggregation.date.mode,
)
elif aggregation.HasField("reference"):
return AggregateReference(pointing_to=list(aggregation.reference.pointing_to))
else:
raise ValueError(
f"Unknown aggregation type {aggregation} encountered in _Aggregate.__parse_property_grpc()"
)

@staticmethod
def _add_groupby_to_builder(
builder: AggregateBuilder, group_by: Union[str, GroupByAggregate, None]
Expand Down
65 changes: 47 additions & 18 deletions weaviate/collections/aggregations/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
GroupByAggregate,
)
from weaviate.collections.classes.filters import _Filters
from weaviate.collections.filters import _FilterToGRPC
from weaviate.exceptions import WeaviateUnsupportedFeatureError
from weaviate.types import NUMBER

Expand Down Expand Up @@ -69,24 +70,52 @@ async def hybrid(
if (return_metrics is None or isinstance(return_metrics, list))
else [return_metrics]
)
builder = self._base(return_metrics, filters, total_count)
builder = self._add_hybrid_to_builder(
builder,
query,
alpha,
vector,
query_properties,
object_limit,
target_vector,
max_vector_distance,
)
builder = self._add_groupby_to_builder(builder, group_by)
res = await self._do(builder)
return (
self._to_aggregate_result(res, return_metrics)
if group_by is None
else self._to_group_by_result(res, return_metrics)
)

if isinstance(group_by, str):
group_by = GroupByAggregate(prop=group_by)

if self._connection._weaviate_version.is_lower_than(1, 29, 0):
# use gql, remove once 1.29 is the minimum supported version

builder = self._base(return_metrics, filters, total_count)
builder = self._add_hybrid_to_builder(
builder,
query,
alpha,
vector,
query_properties,
object_limit,
target_vector,
max_vector_distance,
)
builder = self._add_groupby_to_builder(builder, group_by)
res = await self._do(builder)
return (
self._to_aggregate_result(res, return_metrics)
if group_by is None
else self._to_group_by_result(res, return_metrics)
)
else:
# use grpc
reply = await self._grpc.hybrid(
query=query,
alpha=alpha,
vector=vector,
properties=query_properties,
object_limit=object_limit,
target_vector=target_vector,
distance=max_vector_distance,
aggregations=(
[metric.to_grpc() for metric in return_metrics]
if return_metrics is not None
else []
),
filters=_FilterToGRPC.convert(filters) if filters is not None else None,
group_by=group_by._to_grpc() if group_by is not None else None,
limit=group_by.limit if group_by is not None else None,
objects_count=total_count,
)
return self._to_result(reply)


@syncify.convert
Expand Down
Loading
Loading