From 28aebaeaf5600503669c0c2c995674fb2ad9db1a Mon Sep 17 00:00:00 2001 From: Fabio Massimo Ercoli Date: Thu, 27 Jun 2024 14:58:47 +0200 Subject: [PATCH] HSEARCH-5133 Support avg field-typed --- .../ElasticsearchMetricDoubleAggregation.java | 83 ------------------- .../ElasticsearchMetricFieldAggregation.java | 18 +++- ...sticsearchNumericFieldTypeOptionsStep.java | 3 +- .../dsl/AvgAggregationFieldStep.java | 18 +++- .../dsl/AvgAggregationOptionsStep.java | 8 +- .../dsl/impl/AvgAggregationFieldStepImpl.java | 16 ++-- .../impl/AvgAggregationOptionsStepImpl.java | 17 ++-- .../aggregation/spi/AggregationTypeKeys.java | 2 +- .../ElasticsearchMetricAggregationsIT.java | 25 +++--- 9 files changed, 74 insertions(+), 116 deletions(-) delete mode 100644 backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricDoubleAggregation.java diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricDoubleAggregation.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricDoubleAggregation.java deleted file mode 100644 index 9432d08d360..00000000000 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricDoubleAggregation.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. and Hibernate Authors - */ -package org.hibernate.search.backend.elasticsearch.search.aggregation.impl; - -import org.hibernate.search.backend.elasticsearch.search.common.impl.AbstractElasticsearchCodecAwareSearchQueryElementFactory; -import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope; -import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexValueFieldContext; -import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchDoubleFieldCodec; -import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchFieldCodec; -import org.hibernate.search.engine.search.aggregation.spi.SearchFilterableAggregationBuilder; - -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; - -public class ElasticsearchMetricDoubleAggregation extends AbstractElasticsearchNestableAggregation { - - private final String absoluteFieldPath; - private final String operation; - - public ElasticsearchMetricDoubleAggregation(Builder builder) { - super( builder ); - this.absoluteFieldPath = builder.field.absolutePath(); - this.operation = builder.operation; - } - - @Override - protected final JsonObject doRequest(AggregationRequestContext context) { - JsonObject outerObject = new JsonObject(); - JsonObject innerObject = new JsonObject(); - - outerObject.add( operation, innerObject ); - innerObject.addProperty( "field", absoluteFieldPath ); - return outerObject; - } - - @Override - protected Extractor extractor(AggregationRequestContext context) { - return new MetricDoubleExtractor(); - } - - public static class Factory - extends - AbstractElasticsearchCodecAwareSearchQueryElementFactory, F> { - - private final String operation; - - public Factory(ElasticsearchFieldCodec codec, String operation) { - super( codec ); - this.operation = operation; - } - - @Override - public SearchFilterableAggregationBuilder create(ElasticsearchSearchIndexScope scope, - ElasticsearchSearchIndexValueFieldContext field) { - return new Builder( scope, field, operation ); - } - } - - private static class MetricDoubleExtractor implements Extractor { - @Override - public Double extract(JsonObject aggregationResult, AggregationExtractContext context) { - JsonElement value = aggregationResult.get( "value" ); - return ElasticsearchDoubleFieldCodec.INSTANCE.decode( value ); - } - } - - private static class Builder extends AbstractBuilder implements SearchFilterableAggregationBuilder { - private final String operation; - - private Builder(ElasticsearchSearchIndexScope scope, ElasticsearchSearchIndexValueFieldContext field, - String operation) { - super( scope, field ); - this.operation = operation; - } - - @Override - public ElasticsearchMetricDoubleAggregation build() { - return new ElasticsearchMetricDoubleAggregation( this ); - } - } -} diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java index 05e0d5880d3..41ea43ee0c2 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java @@ -7,6 +7,7 @@ import org.hibernate.search.backend.elasticsearch.search.common.impl.AbstractElasticsearchCodecAwareSearchQueryElementFactory; import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope; import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexValueFieldContext; +import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchDoubleFieldCodec; import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchFieldCodec; import org.hibernate.search.engine.backend.types.converter.runtime.FromDocumentValueConvertContext; import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; @@ -85,18 +86,31 @@ private TypeSelector(ElasticsearchFieldCodec codec, @Override public Builder type(Class expectedType, ValueConvert convert) { + ProjectionConverter projectionConverter = null; + if ( !Double.class.isAssignableFrom( expectedType ) || + field.type().projectionConverter( convert ).valueType().isAssignableFrom( expectedType ) ) { + projectionConverter = field.type().projectionConverter( convert ) + .withConvertedType( expectedType, field ); + } return new Builder<>( codec, scope, field, - field.type().projectionConverter( convert ).withConvertedType( expectedType, field ), - operation ); + projectionConverter, + operation + ); } } + @SuppressWarnings("unchecked") private class MetricFieldExtractor implements Extractor { @Override public K extract(JsonObject aggregationResult, AggregationExtractContext context) { FromDocumentValueConvertContext convertContext = context.fromDocumentValueConvertContext(); JsonElement value = aggregationResult.get( "value" ); JsonElement valueAsString = aggregationResult.get( "value_as_string" ); + + if ( fromFieldValueConverter == null ) { + Double decode = ElasticsearchDoubleFieldCodec.INSTANCE.decode( value ); + return (K) decode; + } return fromFieldValueConverter.fromDocumentValue( codec.decodeAggregationValue( value, valueAsString ), convertContext diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/types/dsl/impl/AbstractElasticsearchNumericFieldTypeOptionsStep.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/types/dsl/impl/AbstractElasticsearchNumericFieldTypeOptionsStep.java index a926f1de732..2b59cfc1cc0 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/types/dsl/impl/AbstractElasticsearchNumericFieldTypeOptionsStep.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/types/dsl/impl/AbstractElasticsearchNumericFieldTypeOptionsStep.java @@ -4,7 +4,6 @@ */ package org.hibernate.search.backend.elasticsearch.types.dsl.impl; -import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricDoubleAggregation; import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricFieldAggregation; import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricLongAggregation; import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchRangeAggregation; @@ -77,7 +76,7 @@ protected final void complete() { builder.queryElementFactory( AggregationTypeKeys.COUNT_DISTINCT, new ElasticsearchMetricLongAggregation.Factory<>( codec, "cardinality" ) ); builder.queryElementFactory( AggregationTypeKeys.AVG, - new ElasticsearchMetricDoubleAggregation.Factory<>( codec, "avg" ) + new ElasticsearchMetricFieldAggregation.Factory<>( codec, "avg" ) ); } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationFieldStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationFieldStep.java index 1f93b0b47a8..8c609500303 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationFieldStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationFieldStep.java @@ -20,9 +20,25 @@ public interface AvgAggregationFieldStep { * Target the given field in the avg aggregation. * * @param fieldPath The path to the index field to aggregate. + * @param type The type of field values. + * @param The type of field values or {@link Double} if a double result is required. + * @return The next step. + */ + default AvgAggregationOptionsStep field(String fieldPath, Class type) { + return field( fieldPath, type, ValueConvert.YES ); + } + + /** + * Target the given field in the avg aggregation. + * + * @param fieldPath The path to the index field to aggregate. + * @param type The type of field values. + * @param The type of field values or {@link Double} if a double result is required. + * @param convert Controls how the ranges passed to the next steps and fetched from the backend should be converted. * See {@link ValueConvert}. * @return The next step. */ - AvgAggregationOptionsStep field(String fieldPath); + AvgAggregationOptionsStep field(String fieldPath, Class type, + ValueConvert convert); } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationOptionsStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationOptionsStep.java index 0336f9e734e..a8a2204260c 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationOptionsStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AvgAggregationOptionsStep.java @@ -13,10 +13,12 @@ * * @param The "self" type (the actual exposed type of this step). * @param The type of factory used to create predicates in {@link #filter(Function)}. + * @param The type of the targeted field. The type of result for this aggregation. */ public interface AvgAggregationOptionsStep< - S extends AvgAggregationOptionsStep, - PDF extends SearchPredicateFactory> - extends AggregationFinalStep, AggregationFilterStep { + S extends AvgAggregationOptionsStep, + PDF extends SearchPredicateFactory, + F> + extends AggregationFinalStep, AggregationFilterStep { } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationFieldStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationFieldStepImpl.java index 4e0da0c55f9..483340f181c 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationFieldStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationFieldStepImpl.java @@ -8,11 +8,12 @@ import org.hibernate.search.engine.search.aggregation.dsl.AvgAggregationOptionsStep; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; import org.hibernate.search.engine.search.aggregation.spi.AggregationTypeKeys; -import org.hibernate.search.engine.search.aggregation.spi.SearchFilterableAggregationBuilder; +import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; +import org.hibernate.search.engine.search.common.ValueConvert; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; +import org.hibernate.search.util.common.impl.Contracts; -public class AvgAggregationFieldStepImpl - implements AvgAggregationFieldStep { +public class AvgAggregationFieldStepImpl implements AvgAggregationFieldStep { private final SearchAggregationDslContext dslContext; public AvgAggregationFieldStepImpl(SearchAggregationDslContext dslContext) { @@ -20,9 +21,12 @@ public AvgAggregationFieldStepImpl(SearchAggregationDslContext } @Override - public AvgAggregationOptionsStep field(String fieldPath) { - SearchFilterableAggregationBuilder builder = dslContext.scope() - .fieldQueryElement( fieldPath, AggregationTypeKeys.AVG ); + public AvgAggregationOptionsStep field(String fieldPath, Class type, + ValueConvert convert) { + Contracts.assertNotNull( fieldPath, "fieldPath" ); + Contracts.assertNotNull( type, "type" ); + FieldMetricAggregationBuilder builder = dslContext.scope() + .fieldQueryElement( fieldPath, AggregationTypeKeys.AVG ).type( type, convert ); return new AvgAggregationOptionsStepImpl<>( builder, dslContext ); } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationOptionsStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationOptionsStepImpl.java index 10502a6957f..e6deb06eec2 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationOptionsStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/AvgAggregationOptionsStepImpl.java @@ -9,37 +9,38 @@ import org.hibernate.search.engine.search.aggregation.SearchAggregation; import org.hibernate.search.engine.search.aggregation.dsl.AvgAggregationOptionsStep; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; -import org.hibernate.search.engine.search.aggregation.spi.SearchFilterableAggregationBuilder; +import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; import org.hibernate.search.engine.search.predicate.SearchPredicate; import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; -class AvgAggregationOptionsStepImpl - implements AvgAggregationOptionsStep, PDF> { - private final SearchFilterableAggregationBuilder builder; +class AvgAggregationOptionsStepImpl + implements AvgAggregationOptionsStep, PDF, F> { + private final FieldMetricAggregationBuilder builder; private final SearchAggregationDslContext dslContext; - AvgAggregationOptionsStepImpl(SearchFilterableAggregationBuilder builder, + AvgAggregationOptionsStepImpl(FieldMetricAggregationBuilder builder, SearchAggregationDslContext dslContext) { this.builder = builder; this.dslContext = dslContext; } @Override - public AvgAggregationOptionsStepImpl filter( + public AvgAggregationOptionsStepImpl filter( Function clauseContributor) { SearchPredicate predicate = clauseContributor.apply( dslContext.predicateFactory() ).toPredicate(); + return filter( predicate ); } @Override - public AvgAggregationOptionsStepImpl filter(SearchPredicate searchPredicate) { + public AvgAggregationOptionsStepImpl filter(SearchPredicate searchPredicate) { builder.filter( searchPredicate ); return this; } @Override - public SearchAggregation toAggregation() { + public SearchAggregation toAggregation() { return builder.build(); } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/spi/AggregationTypeKeys.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/spi/AggregationTypeKeys.java index 66801b4a1ae..0c791a21ece 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/spi/AggregationTypeKeys.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/spi/AggregationTypeKeys.java @@ -28,7 +28,7 @@ private AggregationTypeKeys() { of( IndexFieldTraits.Aggregations.COUNT ); public static final SearchQueryElementTypeKey> COUNT_DISTINCT = of( IndexFieldTraits.Aggregations.COUNT_DISTINCT ); - public static final SearchQueryElementTypeKey> AVG = + public static final SearchQueryElementTypeKey AVG = of( IndexFieldTraits.Aggregations.AVG ); } diff --git a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/tmp/ElasticsearchMetricAggregationsIT.java b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/tmp/ElasticsearchMetricAggregationsIT.java index 0d4ec59ab81..9b4e73b0ad2 100644 --- a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/tmp/ElasticsearchMetricAggregationsIT.java +++ b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/tmp/ElasticsearchMetricAggregationsIT.java @@ -38,8 +38,9 @@ public class ElasticsearchMetricAggregationsIT { private final AggregationKey countConverted = AggregationKey.of( "countConverted" ); private final AggregationKey countDistinctIntegers = AggregationKey.of( "countDistinctIntegers" ); private final AggregationKey countDistinctConverted = AggregationKey.of( "countDistinctConverted" ); - private final AggregationKey avgIntegers = AggregationKey.of( "avgIntegers" ); - private final AggregationKey avgConverted = AggregationKey.of( "avgConverted" ); + private final AggregationKey avgIntegers = AggregationKey.of( "avgIntegers" ); + private final AggregationKey avgConverted = AggregationKey.of( "avgConverted" ); + private final AggregationKey avgIntegersAsDouble = AggregationKey.of( "avgIntegersAsDouble" ); @BeforeEach void setup() { @@ -62,8 +63,9 @@ public void test_filteringResults() { .aggregation( countConverted, f -> f.count().field( "converted" ) ) .aggregation( countDistinctIntegers, f -> f.countDistinct().field( "integer" ) ) .aggregation( countDistinctConverted, f -> f.countDistinct().field( "converted" ) ) - .aggregation( avgIntegers, f -> f.avg().field( "integer" ) ) - .aggregation( avgConverted, f -> f.avg().field( "converted" ) ) + .aggregation( avgIntegers, f -> f.avg().field( "integer", Integer.class ) ) + .aggregation( avgConverted, f -> f.avg().field( "converted", String.class ) ) + .aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class ) ) .toQuery(); SearchResult result = query.fetch( 0 ); @@ -77,8 +79,9 @@ public void test_filteringResults() { assertThat( result.aggregation( countConverted ) ).isEqualTo( 5 ); assertThat( result.aggregation( countDistinctIntegers ) ).isEqualTo( 3 ); assertThat( result.aggregation( countDistinctConverted ) ).isEqualTo( 3 ); - assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5.8 ); - assertThat( result.aggregation( avgConverted ) ).isEqualTo( 5.8 ); + assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5 ); + assertThat( result.aggregation( avgConverted ) ).isEqualTo( "5" ); + assertThat( result.aggregation( avgIntegersAsDouble ) ).isEqualTo( 5.8 ); } @Test @@ -96,8 +99,9 @@ public void test_allResults() { .aggregation( countConverted, f -> f.count().field( "converted" ) ) .aggregation( countDistinctIntegers, f -> f.countDistinct().field( "integer" ) ) .aggregation( countDistinctConverted, f -> f.countDistinct().field( "converted" ) ) - .aggregation( avgIntegers, f -> f.avg().field( "integer" ) ) - .aggregation( avgConverted, f -> f.avg().field( "converted" ) ) + .aggregation( avgIntegers, f -> f.avg().field( "integer", Integer.class ) ) + .aggregation( avgConverted, f -> f.avg().field( "converted", String.class ) ) + .aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class ) ) .toQuery(); SearchResult result = query.fetch( 0 ); @@ -111,8 +115,9 @@ public void test_allResults() { assertThat( result.aggregation( countConverted ) ).isEqualTo( 10 ); assertThat( result.aggregation( countDistinctIntegers ) ).isEqualTo( 6 ); assertThat( result.aggregation( countDistinctConverted ) ).isEqualTo( 6 ); - assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5.5 ); - assertThat( result.aggregation( avgConverted ) ).isEqualTo( 5.5 ); + assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5 ); + assertThat( result.aggregation( avgConverted ) ).isEqualTo( "5" ); + assertThat( result.aggregation( avgIntegersAsDouble ) ).isEqualTo( 5.5 ); } private void initData() {