From 38f3d6b0c79dbbd726a71dafae3c21478abc847a Mon Sep 17 00:00:00 2001 From: marko-bekhta Date: Mon, 16 Sep 2024 09:34:06 +0200 Subject: [PATCH] HSEARCH-5230 Support RAW model type with metric aggregations --- .../ElasticsearchMetricFieldAggregation.java | 160 +++++++++++---- ...LuceneMetricCompensatedSumAggregation.java | 109 +++++++---- ...ctLuceneMetricNumericFieldAggregation.java | 184 +++++++++++++----- .../LuceneAvgCompensatedSumAggregation.java | 20 +- .../LuceneAvgNumericFieldAggregation.java | 154 ++++++++++++++- .../LuceneMaxNumericFieldAggregation.java | 14 +- .../LuceneMinNumericFieldAggregation.java | 14 +- .../LuceneSumCompensatedSumAggregation.java | 20 +- .../LuceneSumNumericFieldAggregation.java | 14 +- .../impl/AbstractLuceneNumericFieldCodec.java | 3 + .../impl/LuceneBigDecimalFieldCodec.java | 5 + .../impl/LuceneBigIntegerFieldCodec.java | 5 + .../util/ElasticsearchTckBackendFeatures.java | 20 ++ .../util/LuceneTckBackendFeatures.java | 12 ++ .../MetricFieldAggregationsIT.java | 56 ++++-- .../MetricNumericFieldsAggregationsIT.java | 16 +- .../MetricAggregationsTestCase.java | 150 +++++++------- .../types/BigIntegerFieldTypeDescriptor.java | 5 + .../types/InstantFieldTypeDescriptor.java | 5 + .../LocalDateTimeFieldTypeDescriptor.java | 5 + .../OffsetDateTimeFieldTypeDescriptor.java | 11 ++ .../types/YearFieldTypeDescriptor.java | 5 + .../ZonedDateTimeFieldTypeDescriptor.java | 11 ++ .../values/MetricAggregationsValues.java | 20 ++ .../testsupport/util/TckBackendFeatures.java | 23 ++- .../testsupport/util/TypeAssertionHelper.java | 33 ++++ 26 files changed, 798 insertions(+), 276 deletions(-) diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java index e8f7f1b933c..108ca7d15e8 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/impl/ElasticsearchMetricFieldAggregation.java @@ -17,7 +17,6 @@ import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; import org.hibernate.search.engine.search.common.ValueModel; -import org.hibernate.search.util.common.AssertionFailure; import com.google.gson.JsonElement; import com.google.gson.JsonObject; @@ -64,15 +63,13 @@ public static ElasticsearchMetricFieldAggregation.Factory avg(Elasticsear } private final String absoluteFieldPath; - private final ProjectionConverter fromFieldValueConverter; - private final ElasticsearchFieldCodec codec; + private final AggregationExtractorBuilder metricFieldExtractorCreator; private final JsonAccessor operation; private ElasticsearchMetricFieldAggregation(Builder builder) { super( builder ); this.absoluteFieldPath = builder.field.absolutePath(); - this.fromFieldValueConverter = builder.fromFieldValueConverter; - this.codec = builder.codec; + this.metricFieldExtractorCreator = builder.metricFieldExtractorCreator; this.operation = builder.operation; } @@ -88,7 +85,7 @@ protected final JsonObject doRequest(AggregationRequestContext context) { @Override protected Extractor extractor(AggregationRequestContext context) { - return new MetricFieldExtractor( nestedPathHierarchy, filter ); + return metricFieldExtractorCreator.extractor( filter ); } private static class Factory @@ -124,72 +121,155 @@ private TypeSelector(ElasticsearchFieldCodec codec, this.operation = operation; } + @SuppressWarnings("unchecked") @Override public Builder type(Class expectedType, ValueModel valueModel) { - ProjectionConverter projectionConverter = null; - if ( useProjectionConverter( expectedType, valueModel ) ) { - projectionConverter = field.type().projectionConverter( valueModel ) - .withConvertedType( expectedType, field ); - } - return new Builder<>( codec, scope, field, - projectionConverter, - operation - ); - } + AggregationExtractorBuilder metricFieldExtractorCreator; - private boolean useProjectionConverter(Class expectedType, ValueModel valueModel) { - if ( !Double.class.isAssignableFrom( expectedType ) ) { - if ( ValueModel.RAW.equals( valueModel ) ) { - throw new AssertionFailure( - "Raw projection converter is not supported with metric aggregations at the moment" ); + if ( ValueModel.RAW.equals( valueModel ) ) { + if ( Double.class.isAssignableFrom( expectedType ) ) { + metricFieldExtractorCreator = (AggregationExtractorBuilder< + T>) new DoubleMetricFieldExtractor.Builder( field.nestedPathHierarchy() ); + } + else { + var projectionConverter = (ProjectionConverter) field.type() + .rawProjectionConverter().withConvertedType( expectedType, field ); + metricFieldExtractorCreator = (AggregationExtractorBuilder) new RawMetricFieldExtractor.Builder<>( + field.nestedPathHierarchy(), + projectionConverter ); } - return true; } - - // expectedType == Double.class - if ( ValueModel.RAW.equals( valueModel ) ) { - return false; + else { + var projectionConverter = field.type() + .projectionConverter( valueModel ).withConvertedType( expectedType, field ); + metricFieldExtractorCreator = + new MetricFieldExtractor.Builder<>( field.nestedPathHierarchy(), projectionConverter, codec ); } - return field.type().projectionConverter( valueModel ).valueType().isAssignableFrom( Double.class ); + + return new Builder<>( scope, field, metricFieldExtractorCreator, operation ); } } - private class MetricFieldExtractor extends AbstractExtractor { - protected MetricFieldExtractor(List nestedPathHierarchy, ElasticsearchSearchPredicate filter) { + private static class MetricFieldExtractor extends AbstractExtractor { + + private final ProjectionConverter fromFieldValueConverter; + private final ElasticsearchFieldCodec codec; + + protected MetricFieldExtractor(List nestedPathHierarchy, ElasticsearchSearchPredicate filter, + ProjectionConverter fromFieldValueConverter, ElasticsearchFieldCodec codec) { super( nestedPathHierarchy, filter ); + this.fromFieldValueConverter = fromFieldValueConverter; + this.codec = codec; } @Override - @SuppressWarnings("unchecked") protected K doExtract(JsonObject aggregationResult, AggregationExtractContext context) { FromDocumentValueConvertContext convertContext = context.fromDocumentValueConvertContext(); Optional value = VALUE_ACCESSOR.get( aggregationResult ); JsonElement valueAsString = aggregationResult.get( "value_as_string" ); - if ( fromFieldValueConverter == null ) { - Double decode = value.orElse( null ); - return (K) decode; - } + return fromFieldValueConverter.fromDocumentValue( codec.decodeAggregationValue( value, valueAsString ), convertContext ); } + + private static class Builder extends AggregationExtractorBuilder { + private final ProjectionConverter fromFieldValueConverter; + private final ElasticsearchFieldCodec codec; + + private Builder(List nestedPathHierarchy, ProjectionConverter fromFieldValueConverter, + ElasticsearchFieldCodec codec) { + super( nestedPathHierarchy ); + this.fromFieldValueConverter = fromFieldValueConverter; + this.codec = codec; + } + + @Override + AbstractExtractor extractor(ElasticsearchSearchPredicate filter) { + return new MetricFieldExtractor<>( nestedPathHierarchy, filter, fromFieldValueConverter, codec ); + } + } + } + + private static class DoubleMetricFieldExtractor extends AbstractExtractor { + protected DoubleMetricFieldExtractor(List nestedPathHierarchy, ElasticsearchSearchPredicate filter) { + super( nestedPathHierarchy, filter ); + } + + @Override + protected Double doExtract(JsonObject aggregationResult, AggregationExtractContext context) { + Optional value = VALUE_ACCESSOR.get( aggregationResult ); + return value.orElse( null ); + } + + private static class Builder extends AggregationExtractorBuilder { + + private Builder(List nestedPathHierarchy) { + super( nestedPathHierarchy ); + } + + @Override + AbstractExtractor extractor(ElasticsearchSearchPredicate filter) { + return new DoubleMetricFieldExtractor( nestedPathHierarchy, filter ); + } + } + } + + private static class RawMetricFieldExtractor extends AbstractExtractor { + + private final ProjectionConverter projectionConverter; + + protected RawMetricFieldExtractor(List nestedPathHierarchy, ElasticsearchSearchPredicate filter, + ProjectionConverter projectionConverter) { + super( nestedPathHierarchy, filter ); + this.projectionConverter = projectionConverter; + } + + @Override + protected K doExtract(JsonObject aggregationResult, AggregationExtractContext context) { + FromDocumentValueConvertContext convertContext = context.fromDocumentValueConvertContext(); + return projectionConverter.fromDocumentValue( aggregationResult, convertContext ); + } + + private static class Builder extends AggregationExtractorBuilder { + private final ProjectionConverter projectionConverter; + + private Builder(List nestedPathHierarchy, ProjectionConverter projectionConverter) { + super( nestedPathHierarchy ); + this.projectionConverter = projectionConverter; + } + + @Override + AbstractExtractor extractor(ElasticsearchSearchPredicate filter) { + return new RawMetricFieldExtractor<>( nestedPathHierarchy, filter, projectionConverter ); + } + } + } + + private abstract static class AggregationExtractorBuilder { + protected final List nestedPathHierarchy; + + protected AggregationExtractorBuilder(List nestedPathHierarchy) { + this.nestedPathHierarchy = nestedPathHierarchy; + } + + abstract AbstractExtractor extractor(ElasticsearchSearchPredicate filter); } private static class Builder extends AbstractBuilder implements FieldMetricAggregationBuilder { - private final ElasticsearchFieldCodec codec; - private final ProjectionConverter fromFieldValueConverter; + private final AggregationExtractorBuilder metricFieldExtractorCreator; private final JsonAccessor operation; - private Builder(ElasticsearchFieldCodec codec, ElasticsearchSearchIndexScope scope, + private Builder(ElasticsearchSearchIndexScope scope, ElasticsearchSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter, JsonAccessor operation) { + AggregationExtractorBuilder metricFieldExtractorCreator, + JsonAccessor operation) { super( scope, field ); - this.codec = codec; - this.fromFieldValueConverter = fromFieldValueConverter; + this.metricFieldExtractorCreator = metricFieldExtractorCreator; this.operation = operation; } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricCompensatedSumAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricCompensatedSumAggregation.java index 1807b1416fb..1e02a1c45a8 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricCompensatedSumAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricCompensatedSumAggregation.java @@ -16,10 +16,10 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain; +import org.hibernate.search.engine.backend.types.converter.runtime.FromDocumentValueConvertContext; import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; import org.hibernate.search.engine.search.common.ValueModel; -import org.hibernate.search.util.common.AssertionFailure; /** * @param The type of field values. @@ -31,9 +31,9 @@ public abstract class AbstractLuceneMetricCompensatedSumAggregation indexNames; private final String absoluteFieldPath; - private final AbstractLuceneNumericFieldCodec codec; + protected final AbstractLuceneNumericFieldCodec codec; private final LuceneNumericDomain numericDomain; - private final ProjectionConverter fromFieldValueConverter; + private final ExtractedValueConverter extractedConverter; protected CollectorKey collectorKey; protected CollectorKey, Double> compensatedSumCollectorKey; @@ -44,7 +44,7 @@ public abstract class AbstractLuceneMetricCompensatedSumAggregation request(AggregationRequestContext context) { JoiningLongMultiValuesSource source = JoiningLongMultiValuesSource.fromField( absoluteFieldPath, createNestedDocsProvider( context ) ); - fillCollectors( source, context, numericDomain ); + fillCollectors( source, context ); return new LuceneNumericMetricFieldAggregationExtraction(); } - abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context, - LuceneNumericDomain numericDomain); + abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context); @Override public Set indexNames() { @@ -67,75 +66,109 @@ public Set indexNames() { private class LuceneNumericMetricFieldAggregationExtraction implements Extractor { @Override - @SuppressWarnings("unchecked") public K extract(AggregationExtractContext context) { E extracted = extractEncoded( context, numericDomain ); - F decode = codec.decode( extracted ); - if ( fromFieldValueConverter == null ) { - return (K) decode; - } - return fromFieldValueConverter.fromDocumentValue( decode, context.fromDocumentValueConvertContext() ); + return extractedConverter.convert( extracted, context.fromDocumentValueConvertContext() ); } } abstract E extractEncoded(AggregationExtractContext context, LuceneNumericDomain numericDomain); - protected abstract static class TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected final AbstractLuceneNumericFieldCodec codec; + protected abstract static class ExtractedValueConverter { + + abstract K convert(E extracted, FromDocumentValueConvertContext context); + } + + protected abstract static class TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { + protected final AbstractLuceneNumericFieldCodec codec; protected final LuceneSearchIndexScope scope; protected final LuceneSearchIndexValueFieldContext field; - protected TypeSelector(AbstractLuceneNumericFieldCodec codec, + protected TypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { this.codec = codec; this.scope = scope; this.field = field; } + @SuppressWarnings("unchecked") @Override public Builder type(Class expectedType, ValueModel valueModel) { - ProjectionConverter projectionConverter = null; - if ( useProjectionConverter( expectedType, valueModel ) ) { - projectionConverter = field.type().projectionConverter( valueModel ) + ExtractedValueConverter extractedConverter; + + if ( ValueModel.RAW.equals( valueModel ) ) { + if ( Double.class.isAssignableFrom( expectedType ) ) { + extractedConverter = (ExtractedValueConverter) new DoubleExtractedValueConverter<>(); + } + else { + var projectionConverter = (ProjectionConverter) field.type().rawProjectionConverter() + .withConvertedType( expectedType, field ); + extractedConverter = new RawExtractedValueConverter<>( projectionConverter ); + } + } + else { + var projectionConverter = field.type().projectionConverter( valueModel ) .withConvertedType( expectedType, field ); + extractedConverter = new DecodingExtractedValueConverter<>( projectionConverter, codec ); } - return getFtBuilder( projectionConverter ); + return getFtBuilder( extractedConverter ); } - private boolean useProjectionConverter(Class expectedType, ValueModel valueModel) { - if ( !Double.class.isAssignableFrom( expectedType ) ) { - if ( ValueModel.RAW.equals( valueModel ) ) { - throw new AssertionFailure( - "Raw projection converter is not supported with metric aggregations at the moment" ); - } - return true; - } + protected abstract Builder getFtBuilder( + ExtractedValueConverter extractedConverter); + } - // expectedType == Double.class - if ( ValueModel.RAW.equals( valueModel ) ) { - return false; - } - return field.type().projectionConverter( valueModel ).valueType().isAssignableFrom( Double.class ); + private static class DoubleExtractedValueConverter extends ExtractedValueConverter { + + @Override + Double convert(E extracted, FromDocumentValueConvertContext context) { + return extracted.doubleValue(); } + } - protected abstract Builder getFtBuilder( - ProjectionConverter projectionConverter); + private static class RawExtractedValueConverter extends ExtractedValueConverter { + private final ProjectionConverter projectionConverter; + + private RawExtractedValueConverter(ProjectionConverter projectionConverter) { + this.projectionConverter = projectionConverter; + } + + @Override + T convert(E extracted, FromDocumentValueConvertContext context) { + return projectionConverter.fromDocumentValue( extracted, context ); + } + } + + private static class DecodingExtractedValueConverter extends ExtractedValueConverter { + private final ProjectionConverter projectionConverter; + private final AbstractLuceneNumericFieldCodec codec; + + private DecodingExtractedValueConverter(ProjectionConverter projectionConverter, + AbstractLuceneNumericFieldCodec codec) { + this.projectionConverter = projectionConverter; + this.codec = codec; + } + + @Override + T convert(E extracted, FromDocumentValueConvertContext context) { + return projectionConverter.fromDocumentValue( codec.decode( extracted ), context ); + } } protected abstract static class Builder extends AbstractBuilder implements FieldMetricAggregationBuilder { private final AbstractLuceneNumericFieldCodec codec; - private final ProjectionConverter fromFieldValueConverter; + private final ExtractedValueConverter extractedConverter; public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { + ExtractedValueConverter extractedConverter) { super( scope, field ); this.codec = codec; - this.fromFieldValueConverter = fromFieldValueConverter; + this.extractedConverter = extractedConverter; } } } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricNumericFieldAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricNumericFieldAggregation.java index aa1acc83f64..bcdfae57cf0 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricNumericFieldAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/AbstractLuceneMetricNumericFieldAggregation.java @@ -6,8 +6,6 @@ import java.util.Set; -import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.AggregationFunctionCollector; -import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.Count; import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey; import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource; import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext; @@ -17,10 +15,8 @@ import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain; import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; -import org.hibernate.search.engine.cfg.spi.NumberUtils; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; import org.hibernate.search.engine.search.common.ValueModel; -import org.hibernate.search.util.common.AssertionFailure; /** * @param The type of field values. @@ -32,22 +28,19 @@ public abstract class AbstractLuceneMetricNumericFieldAggregation indexNames; private final String absoluteFieldPath; - private final AbstractLuceneNumericFieldCodec codec; - private final LuceneNumericDomain numericDomain; - private final ProjectionConverter fromFieldValueConverter; + protected final AbstractLuceneNumericFieldCodec codec; + protected final LuceneNumericDomain numericDomain; + private final AbstractExtractorBuilder extractorCreator; protected CollectorKey collectorKey; - // Supplementary collector used by the avg function - protected CollectorKey, Long> countCollectorKey; - AbstractLuceneMetricNumericFieldAggregation(Builder builder) { super( builder ); this.indexNames = builder.scope.hibernateSearchIndexNames(); this.absoluteFieldPath = builder.field.absolutePath(); this.codec = builder.codec; this.numericDomain = codec.getDomain(); - this.fromFieldValueConverter = builder.fromFieldValueConverter; + this.extractorCreator = builder.extractorCreator; } @Override @@ -56,7 +49,7 @@ public Extractor request(AggregationRequestContext context) { absoluteFieldPath, createNestedDocsProvider( context ) ); fillCollectors( source, context ); - return new LuceneNumericMetricFieldAggregationExtraction(); + return extractorCreator.extractor( this ); } abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context); @@ -66,39 +59,116 @@ public Set indexNames() { return indexNames; } - private class LuceneNumericMetricFieldAggregationExtraction implements Extractor { + private static class LuceneNumericMetricFieldAggregationExtraction implements Extractor { + private final CollectorKey collectorKey; + private final AbstractLuceneNumericFieldCodec codec; + private final ProjectionConverter fromFieldValueConverter; + + private LuceneNumericMetricFieldAggregationExtraction(CollectorKey collectorKey, + AbstractLuceneNumericFieldCodec codec, ProjectionConverter fromFieldValueConverter) { + this.collectorKey = collectorKey; + this.codec = codec; + this.fromFieldValueConverter = fromFieldValueConverter; + } @Override - @SuppressWarnings("unchecked") public K extract(AggregationExtractContext context) { Long collector = context.getFacets( collectorKey ); - if ( countCollectorKey != null ) { - Long counts = context.getFacets( countCollectorKey ); - Double avg = ( (double) collector / counts ); - if ( fromFieldValueConverter == null ) { - return (K) avg; - } - collector = NumberUtils.toLong( avg ); + E e = codec.getDomain().sortedDocValueToTerm( collector ); + F decode = codec.decode( e ); + return fromFieldValueConverter.fromDocumentValue( decode, context.fromDocumentValueConvertContext() ); + } + + private static class Builder extends AbstractExtractorBuilder { + private final ProjectionConverter fromFieldValueConverter; + + private Builder(ProjectionConverter fromFieldValueConverter) { + this.fromFieldValueConverter = fromFieldValueConverter; } - if ( fromFieldValueConverter == null ) { - Double decode = collector.doubleValue(); - return (K) decode; + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return new LuceneNumericMetricFieldAggregationExtraction<>( + aggregation.collectorKey, + aggregation.codec, + fromFieldValueConverter + ); } + } + } - E e = numericDomain.sortedDocValueToTerm( collector ); - F decode = codec.decode( e ); - return fromFieldValueConverter.fromDocumentValue( decode, context.fromDocumentValueConvertContext() ); + private static class LuceneNumericMetricFieldAggregationDoubleExtraction implements Extractor { + + private final CollectorKey collectorKey; + private final AbstractLuceneNumericFieldCodec codec; + + private LuceneNumericMetricFieldAggregationDoubleExtraction(CollectorKey collectorKey, + AbstractLuceneNumericFieldCodec codec) { + this.collectorKey = collectorKey; + this.codec = codec; + } + + @Override + public Double extract(AggregationExtractContext context) { + Long collector = context.getFacets( collectorKey ); + + return codec.sortedDocValueToDouble( collector ); + } + + private static class Builder extends AbstractExtractorBuilder { + + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return new LuceneNumericMetricFieldAggregationDoubleExtraction( + aggregation.collectorKey, + aggregation.codec + ); + } } } - protected abstract static class TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected final AbstractLuceneNumericFieldCodec codec; + private static class LuceneNumericMetricFieldAggregationRawExtraction implements Extractor { + + private final CollectorKey collectorKey; + private final LuceneNumericDomain numericDomain; + + private LuceneNumericMetricFieldAggregationRawExtraction(CollectorKey collectorKey, + LuceneNumericDomain numericDomain) { + this.collectorKey = collectorKey; + this.numericDomain = numericDomain; + } + + @SuppressWarnings("unchecked") + @Override + public K extract(AggregationExtractContext context) { + Long collector = context.getFacets( collectorKey ); + return (K) numericDomain.sortedDocValueToTerm( collector ); + } + + private static class Builder extends AbstractExtractorBuilder { + + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return new LuceneNumericMetricFieldAggregationRawExtraction<>( + aggregation.collectorKey, + aggregation.numericDomain + ); + } + } + } + + protected abstract static class AbstractExtractorBuilder { + + abstract Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation); + } + + protected abstract static class TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { + protected final AbstractLuceneNumericFieldCodec codec; protected final LuceneSearchIndexScope scope; protected final LuceneSearchIndexValueFieldContext field; - protected TypeSelector(AbstractLuceneNumericFieldCodec codec, + protected TypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { this.codec = codec; this.scope = scope; @@ -107,47 +177,57 @@ protected TypeSelector(AbstractLuceneNumericFieldCodec codec, @Override public Builder type(Class expectedType, ValueModel valueModel) { - ProjectionConverter projectionConverter = null; - if ( useProjectionConverter( expectedType, valueModel ) ) { - projectionConverter = field.type().projectionConverter( valueModel ) + AbstractExtractorBuilder extractorCreator; + if ( ValueModel.RAW.equals( valueModel ) ) { + if ( Double.class.isAssignableFrom( expectedType ) ) { + extractorCreator = doubleExtractor(); + } + else { + var projectionConverter = field.type().rawProjectionConverter() + .withConvertedType( expectedType, field ); + extractorCreator = rawExtractor( projectionConverter ); + } + } + else { + var projectionConverter = field.type().projectionConverter( valueModel ) .withConvertedType( expectedType, field ); + extractorCreator = extractor( projectionConverter ); } - return getFtBuilder( projectionConverter ); + return getFtBuilder( extractorCreator ); } - private boolean useProjectionConverter(Class expectedType, ValueModel valueModel) { - if ( !Double.class.isAssignableFrom( expectedType ) ) { - if ( ValueModel.RAW.equals( valueModel ) ) { - throw new AssertionFailure( - "Raw projection converter is not supported with metric aggregations at the moment" ); - } - return true; - } + protected AbstractExtractorBuilder extractor(ProjectionConverter projectionConverter) { + return new LuceneNumericMetricFieldAggregationExtraction.Builder<>( projectionConverter ); + } - // expectedType == Double.class - if ( ValueModel.RAW.equals( valueModel ) ) { - return false; - } - return field.type().projectionConverter( valueModel ).valueType().isAssignableFrom( Double.class ); + // we've checked the types in the place where we are calling this method: + protected AbstractExtractorBuilder rawExtractor(ProjectionConverter projectionConverter) { + return new LuceneNumericMetricFieldAggregationRawExtraction.Builder<>(); + } + + // we've checked the types in the place where we are calling this method: + @SuppressWarnings("unchecked") + protected AbstractExtractorBuilder doubleExtractor() { + return (AbstractExtractorBuilder) new LuceneNumericMetricFieldAggregationDoubleExtraction.Builder<>(); } - protected abstract Builder getFtBuilder( - ProjectionConverter projectionConverter); + protected abstract Builder getFtBuilder(AbstractExtractorBuilder extractorCreator); + } protected abstract static class Builder extends AbstractBuilder implements FieldMetricAggregationBuilder { private final AbstractLuceneNumericFieldCodec codec; - private final ProjectionConverter fromFieldValueConverter; + private final AbstractExtractorBuilder extractorCreator; public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { + AbstractExtractorBuilder extractorCreator) { super( scope, field ); this.codec = codec; - this.fromFieldValueConverter = fromFieldValueConverter; + this.extractorCreator = extractorCreator; } } } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgCompensatedSumAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgCompensatedSumAggregation.java index c8b677752c3..c392eee44b8 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgCompensatedSumAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgCompensatedSumAggregation.java @@ -14,7 +14,6 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain; -import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneAvgCompensatedSumAggregation @@ -29,10 +28,9 @@ public static Factory factory(AbstractLuceneNumericFieldCodec codec } @Override - void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context, - LuceneNumericDomain numericDomain) { - CompensatedSumCollectorFactory sumCollectorFactory = new CompensatedSumCollectorFactory( source, - numericDomain::sortedDocValueToDouble ); + void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context) { + CompensatedSumCollectorFactory sumCollectorFactory = + new CompensatedSumCollectorFactory( source, codec::sortedDocValueToDouble ); compensatedSumCollectorKey = sumCollectorFactory.getCollectorKey(); context.requireCollector( sumCollectorFactory ); @@ -65,18 +63,18 @@ public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope< } } - protected static class FunctionTypeSelector extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + ExtractedValueConverter extractedConverter) { + return new Builder<>( codec, scope, field, extractedConverter ); } } @@ -86,8 +84,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + ExtractedValueConverter extractedConverter) { + super( codec, scope, field, extractedConverter ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgNumericFieldAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgNumericFieldAggregation.java index ed0f5c98b52..ad44a7f9402 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgNumericFieldAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneAvgNumericFieldAggregation.java @@ -4,15 +4,20 @@ */ package org.hibernate.search.backend.lucene.types.aggregation.impl; +import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.AggregationFunctionCollector; +import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.Count; import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountCollectorFactory; import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.SumCollectorFactory; +import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey; import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource; +import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext; import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationRequestContext; import org.hibernate.search.backend.lucene.search.common.impl.AbstractLuceneCodecAwareSearchQueryElementFactory; import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope; import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; +import org.hibernate.search.engine.cfg.spi.NumberUtils; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneAvgNumericFieldAggregation @@ -22,6 +27,9 @@ public static Factory factory(AbstractLuceneNumericFieldCodec codec return new Factory<>( codec ); } + // Supplementary collector used by the avg function + protected CollectorKey, Long> countCollectorKey; + LuceneAvgNumericFieldAggregation(Builder builder) { super( builder ); } @@ -36,6 +44,124 @@ void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestConte context.requireCollector( countCollectorFactory ); } + private static class LuceneNumericMetricFieldAggregationExtraction implements Extractor { + private final CollectorKey collectorKey; + private final CollectorKey countCollectorKey; + private final AbstractLuceneNumericFieldCodec codec; + private final ProjectionConverter fromFieldValueConverter; + + private LuceneNumericMetricFieldAggregationExtraction(CollectorKey collectorKey, + CollectorKey countCollectorKey, + AbstractLuceneNumericFieldCodec codec, ProjectionConverter fromFieldValueConverter) { + this.collectorKey = collectorKey; + this.countCollectorKey = countCollectorKey; + this.codec = codec; + this.fromFieldValueConverter = fromFieldValueConverter; + } + + @Override + public K extract(AggregationExtractContext context) { + Long collector = context.getFacets( collectorKey ); + Long counts = context.getFacets( countCollectorKey ); + Double avg = ( (double) collector / counts ); + collector = NumberUtils.toLong( avg ); + + E e = codec.getDomain().sortedDocValueToTerm( collector ); + F decode = codec.decode( e ); + return fromFieldValueConverter.fromDocumentValue( decode, context.fromDocumentValueConvertContext() ); + } + + private static class Builder extends AbstractExtractorBuilder { + private final ProjectionConverter fromFieldValueConverter; + + private Builder(ProjectionConverter fromFieldValueConverter) { + this.fromFieldValueConverter = fromFieldValueConverter; + } + + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return new LuceneNumericMetricFieldAggregationExtraction<>( + aggregation.collectorKey, + ( (LuceneAvgNumericFieldAggregation) aggregation ).countCollectorKey, + aggregation.codec, + fromFieldValueConverter + ); + } + } + } + + private static class LuceneNumericMetricFieldAggregationDoubleExtraction implements Extractor { + + private final CollectorKey collectorKey; + private final CollectorKey countCollectorKey; + private final AbstractLuceneNumericFieldCodec codec; + + private LuceneNumericMetricFieldAggregationDoubleExtraction(CollectorKey collectorKey, + CollectorKey countCollectorKey, + AbstractLuceneNumericFieldCodec codec) { + this.collectorKey = collectorKey; + this.countCollectorKey = countCollectorKey; + this.codec = codec; + } + + @Override + public Double extract(AggregationExtractContext context) { + double collector = codec.sortedDocValueToDouble( context.getFacets( collectorKey ) ); + Long counts = context.getFacets( countCollectorKey ); + return ( collector / counts ); + } + + private static class Builder extends AbstractExtractorBuilder { + + @SuppressWarnings("unchecked") + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return (Extractor) new LuceneNumericMetricFieldAggregationDoubleExtraction<>( + aggregation.collectorKey, + ( (LuceneAvgNumericFieldAggregation) aggregation ).countCollectorKey, + aggregation.codec + ); + } + } + } + + private static class LuceneNumericMetricFieldAggregationRawExtraction implements Extractor { + + private final CollectorKey collectorKey; + private final CollectorKey countCollectorKey; + private final AbstractLuceneNumericFieldCodec codec; + + private LuceneNumericMetricFieldAggregationRawExtraction(CollectorKey collectorKey, + CollectorKey countCollectorKey, + AbstractLuceneNumericFieldCodec codec) { + this.collectorKey = collectorKey; + this.countCollectorKey = countCollectorKey; + this.codec = codec; + } + + @Override + public E extract(AggregationExtractContext context) { + Long collector = context.getFacets( collectorKey ); + Long counts = context.getFacets( countCollectorKey ); + Double avg = ( (double) collector / counts ); + collector = NumberUtils.toLong( avg ); + return codec.getDomain().sortedDocValueToTerm( collector ); + } + + private static class Builder extends AbstractExtractorBuilder { + + @SuppressWarnings("unchecked") + @Override + Extractor extractor(AbstractLuceneMetricNumericFieldAggregation aggregation) { + return (Extractor) new LuceneNumericMetricFieldAggregationRawExtraction<>( + aggregation.collectorKey, + ( (LuceneAvgNumericFieldAggregation) aggregation ).countCollectorKey, + aggregation.codec + ); + } + } + } + public static class Factory extends AbstractLuceneCodecAwareSearchQueryElementFactory extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override - protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + protected Builder getFtBuilder(AbstractExtractorBuilder extractorCreator) { + return new Builder<>( codec, scope, field, extractorCreator ); + } + + @Override + protected AbstractExtractorBuilder extractor(ProjectionConverter projectionConverter) { + return new LuceneNumericMetricFieldAggregationExtraction.Builder<>( projectionConverter ); + } + + @Override + protected AbstractExtractorBuilder rawExtractor(ProjectionConverter projectionConverter) { + return new LuceneNumericMetricFieldAggregationRawExtraction.Builder<>(); + } + + @Override + protected AbstractExtractorBuilder doubleExtractor() { + return new LuceneNumericMetricFieldAggregationDoubleExtraction.Builder<>(); } } @@ -73,8 +213,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + AbstractExtractorBuilder extractorCreator) { + super( codec, scope, field, extractorCreator ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMaxNumericFieldAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMaxNumericFieldAggregation.java index 33829737b32..0c747cb2100 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMaxNumericFieldAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMaxNumericFieldAggregation.java @@ -11,7 +11,6 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope; import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; -import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneMaxNumericFieldAggregation @@ -48,18 +47,17 @@ public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope< } } - protected static class FunctionTypeSelector extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override - protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + protected Builder getFtBuilder(AbstractExtractorBuilder extractorCreator) { + return new Builder<>( codec, scope, field, extractorCreator ); } } @@ -69,8 +67,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + AbstractExtractorBuilder extractorCreator) { + super( codec, scope, field, extractorCreator ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMinNumericFieldAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMinNumericFieldAggregation.java index 70bbd260518..6c95008880f 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMinNumericFieldAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneMinNumericFieldAggregation.java @@ -11,7 +11,6 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope; import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; -import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneMinNumericFieldAggregation @@ -48,18 +47,17 @@ public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope< } } - protected static class FunctionTypeSelector extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override - protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + protected Builder getFtBuilder(AbstractExtractorBuilder extractorCreator) { + return new Builder<>( codec, scope, field, extractorCreator ); } } @@ -69,8 +67,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + AbstractExtractorBuilder extractorCreator) { + super( codec, scope, field, extractorCreator ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumCompensatedSumAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumCompensatedSumAggregation.java index 0397eac198a..b91a81ca70f 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumCompensatedSumAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumCompensatedSumAggregation.java @@ -13,7 +13,6 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain; -import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneSumCompensatedSumAggregation @@ -28,10 +27,9 @@ public static Factory factory(AbstractLuceneNumericFieldCodec codec } @Override - void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context, - LuceneNumericDomain numericDomain) { - CompensatedSumCollectorFactory collectorFactory = new CompensatedSumCollectorFactory( source, - numericDomain::sortedDocValueToDouble ); + void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context) { + CompensatedSumCollectorFactory collectorFactory = + new CompensatedSumCollectorFactory( source, codec::sortedDocValueToDouble ); compensatedSumCollectorKey = collectorFactory.getCollectorKey(); context.requireCollector( collectorFactory ); } @@ -58,18 +56,18 @@ public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope< } } - protected static class FunctionTypeSelector extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + ExtractedValueConverter extractedConverter) { + return new Builder<>( codec, scope, field, extractedConverter ); } } @@ -79,8 +77,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + ExtractedValueConverter extractedConverter) { + super( codec, scope, field, extractedConverter ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumNumericFieldAggregation.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumNumericFieldAggregation.java index 3630e971016..b9c9766923b 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumNumericFieldAggregation.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/aggregation/impl/LuceneSumNumericFieldAggregation.java @@ -11,7 +11,6 @@ import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope; import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext; import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec; -import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter; import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder; public class LuceneSumNumericFieldAggregation @@ -48,18 +47,17 @@ public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope< } } - protected static class FunctionTypeSelector extends TypeSelector + protected static class FunctionTypeSelector extends TypeSelector implements FieldMetricAggregationBuilder.TypeSelector { - protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, + protected FunctionTypeSelector(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field) { super( codec, scope, field ); } @Override - protected Builder getFtBuilder( - ProjectionConverter projectionConverter) { - return new Builder<>( codec, scope, field, projectionConverter ); + protected Builder getFtBuilder(AbstractExtractorBuilder extractorCreator) { + return new Builder<>( codec, scope, field, extractorCreator ); } } @@ -69,8 +67,8 @@ protected static class Builder public Builder(AbstractLuceneNumericFieldCodec codec, LuceneSearchIndexScope scope, LuceneSearchIndexValueFieldContext field, - ProjectionConverter fromFieldValueConverter) { - super( codec, scope, field, fromFieldValueConverter ); + AbstractExtractorBuilder extractorCreator) { + super( codec, scope, field, extractorCreator ); } @Override diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/AbstractLuceneNumericFieldCodec.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/AbstractLuceneNumericFieldCodec.java index 29bfb6575bb..0fdf9b167de 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/AbstractLuceneNumericFieldCodec.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/AbstractLuceneNumericFieldCodec.java @@ -64,4 +64,7 @@ public boolean isCompatibleWith(LuceneFieldCodec obj) { abstract void addStoredToDocument(LuceneDocumentContent documentBuilder, String absoluteFieldPath, F value, E encodedValue); + public Double sortedDocValueToDouble(Long value) { + return getDomain().sortedDocValueToDouble( value ); + } } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigDecimalFieldCodec.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigDecimalFieldCodec.java index 003279c790b..f7ffeb50677 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigDecimalFieldCodec.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigDecimalFieldCodec.java @@ -102,4 +102,9 @@ private BigDecimal scale(Long value) { public Class encodedType() { return Long.class; } + + @Override + public Double sortedDocValueToDouble(Long value) { + return scale( value ).doubleValue(); + } } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigIntegerFieldCodec.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigIntegerFieldCodec.java index 13d49dc0bb6..31bd3d2e719 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigIntegerFieldCodec.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/types/codec/impl/LuceneBigIntegerFieldCodec.java @@ -103,4 +103,9 @@ private BigDecimal scale(Long value) { public Class encodedType() { return Long.class; } + + @Override + public Double sortedDocValueToDouble(Long value) { + return scale( value ).doubleValue(); + } } diff --git a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/testsupport/util/ElasticsearchTckBackendFeatures.java b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/testsupport/util/ElasticsearchTckBackendFeatures.java index a1ef463f2d3..37ec2e4293c 100644 --- a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/testsupport/util/ElasticsearchTckBackendFeatures.java +++ b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/testsupport/util/ElasticsearchTckBackendFeatures.java @@ -52,6 +52,7 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.JsonElement; +import com.google.gson.JsonObject; public class ElasticsearchTckBackendFeatures extends TckBackendFeatures { @@ -494,4 +495,23 @@ public boolean rangeAggregationsDoNotIgnoreQuery() { public boolean negativeDecimalScaleIsAppliedToAvgAggregationFunction() { return false; } + + @Override + public T fromRawAggregation(FieldTypeDescriptor typeDescriptor, T value) { + + JsonObject jsonObject = gson.fromJson( value.toString(), JsonObject.class ); + return (T) ( gson.toJson( jsonObject.has( "value_as_string" ) + ? jsonObject.get( "value_as_string" ) + : jsonObject.get( "value" ) ) ); + } + + @Override + public boolean rawAggregationProduceSensibleDoubleValue(FieldTypeDescriptor fFieldTypeDescriptor) { + if ( YearFieldTypeDescriptor.INSTANCE.equals( fFieldTypeDescriptor ) + || YearMonthFieldTypeDescriptor.INSTANCE.equals( fFieldTypeDescriptor ) + || LocalDateFieldTypeDescriptor.INSTANCE.equals( fFieldTypeDescriptor ) ) { + return false; + } + return super.rawAggregationProduceSensibleDoubleValue( fFieldTypeDescriptor ); + } } diff --git a/integrationtest/backend/lucene/src/test/java/org/hibernate/search/integrationtest/backend/lucene/testsupport/util/LuceneTckBackendFeatures.java b/integrationtest/backend/lucene/src/test/java/org/hibernate/search/integrationtest/backend/lucene/testsupport/util/LuceneTckBackendFeatures.java index 1e0b89efe9e..fb41d9eb9da 100644 --- a/integrationtest/backend/lucene/src/test/java/org/hibernate/search/integrationtest/backend/lucene/testsupport/util/LuceneTckBackendFeatures.java +++ b/integrationtest/backend/lucene/src/test/java/org/hibernate/search/integrationtest/backend/lucene/testsupport/util/LuceneTckBackendFeatures.java @@ -181,4 +181,16 @@ public Class rawType(FieldTypeDescriptor descriptor) { } return descriptor.getJavaType(); } + + @SuppressWarnings("unchecked") + @Override + public T fromRawAggregation(FieldTypeDescriptor descriptor, T value) { + if ( BigIntegerFieldTypeDescriptor.INSTANCE.equals( descriptor ) ) { + return (T) ( (Number) ( ( (Number) value ).doubleValue() * 100 ) ); + } + if ( BigDecimalFieldTypeDescriptor.INSTANCE.equals( descriptor ) ) { + return (T) ( (Number) ( ( (Number) value ).doubleValue() / 100 ) ); + } + return super.fromRawAggregation( descriptor, value ); + } } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java index cf9cde3c152..efe1257e64c 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java @@ -5,6 +5,7 @@ package org.hibernate.search.integrationtest.backend.tck.search.aggregation; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import java.util.ArrayList; import java.util.LinkedHashSet; @@ -14,10 +15,13 @@ import org.hibernate.search.engine.backend.document.model.dsl.IndexSchemaElement; import org.hibernate.search.engine.backend.types.Aggregable; +import org.hibernate.search.engine.search.common.ValueModel; import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding; import org.hibernate.search.integrationtest.backend.tck.testsupport.operations.MetricAggregationsTestCase; import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor; import org.hibernate.search.integrationtest.backend.tck.testsupport.types.StandardFieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TckConfiguration; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TypeAssertionHelper; import org.hibernate.search.integrationtest.backend.tck.testsupport.util.extension.SearchSetupHelper; import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer; import org.hibernate.search.util.impl.integrationtest.mapper.stub.SimpleMappedIndex; @@ -29,7 +33,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -public class MetricFieldAggregationsIT { +public class MetricFieldAggregationsIT { private static final Set> supportedFieldTypes = new LinkedHashSet<>(); private static final List> testCases = new ArrayList<>(); @@ -78,16 +82,46 @@ static void setup() { @ParameterizedTest(name = "{0}") @MethodSource("params") - public void test(MetricAggregationsTestCase testCase) { + public void test(MetricAggregationsTestCase testCase) { StubMappingScope scope = mainIndex.createScope(); - MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding() ); - if ( result.expectedSum() != null ) { - assertThat( result.computedSum() ).isEqualTo( result.expectedSum() ); - } - assertThat( result.computedMin() ).isEqualTo( result.expectedMin() ); - assertThat( result.computedMax() ).isEqualTo( result.expectedMax() ); - assertThat( result.computedCount() ).isEqualTo( result.expectedCount() ); - assertThat( result.computedCountDistinct() ).isEqualTo( result.expectedCountDistinct() ); - assertThat( result.computedAvg() ).isEqualTo( result.expectedAvg() ); + MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding(), + ValueModel.MAPPING, TypeAssertionHelper.identity( testCase.typeDescriptor() ) + ); + + result.validate(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("params") + public void testString(MetricAggregationsTestCase testCase) { + StubMappingScope scope = mainIndex.createScope(); + MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding(), + ValueModel.STRING, TypeAssertionHelper.string( testCase.typeDescriptor() ) + ); + result.validate(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("params") + public void testRaw(MetricAggregationsTestCase testCase) { + StubMappingScope scope = mainIndex.createScope(); + MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding(), + ValueModel.RAW, TypeAssertionHelper.raw( testCase.typeDescriptor() ) + ); + result.validate(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("params") + public void testRawDouble(MetricAggregationsTestCase testCase) { + assumeTrue( + TckConfiguration.get().getBackendFeatures() + .rawAggregationProduceSensibleDoubleValue( testCase.typeDescriptor() ), + "Some date-time types with Elasticsearch backends can produce some 'garbage' double values, but for those we usually rely on the value_as_string anyway." ); + StubMappingScope scope = mainIndex.createScope(); + MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding(), + ValueModel.RAW, TypeAssertionHelper.rawDouble( testCase.typeDescriptor() ) + ); + result.validateDouble(); } } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricNumericFieldsAggregationsIT.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricNumericFieldsAggregationsIT.java index 36e21e6ce26..735dde12b03 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricNumericFieldsAggregationsIT.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricNumericFieldsAggregationsIT.java @@ -5,7 +5,6 @@ package org.hibernate.search.integrationtest.backend.tck.search.aggregation; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.math.BigDecimal; import java.math.BigInteger; @@ -25,7 +24,6 @@ import org.hibernate.search.engine.search.query.SearchResult; import org.hibernate.search.engine.search.query.dsl.SearchQueryOptionsStep; import org.hibernate.search.integrationtest.backend.tck.testsupport.util.extension.SearchSetupHelper; -import org.hibernate.search.util.common.AssertionFailure; import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer; import org.hibernate.search.util.impl.integrationtest.mapper.stub.SimpleMappedIndex; import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubLoadingOptionsStep; @@ -165,11 +163,8 @@ void test_allResults() { private SearchQuery defineAggregations( SearchQueryOptionsStep options) { - assertThatThrownBy( () -> { - options.aggregation( sumIntegersRaw, f -> f.sum().field( "integer", Object.class, ValueModel.RAW ) ); - } ) - .isInstanceOf( AssertionFailure.class ) - .hasMessageContaining( "Raw projection converter is not supported with metric aggregations at the moment" ); + + options.aggregation( sumIntegersRaw, f -> f.sum().field( "integer", Object.class, ValueModel.RAW ) ); return options .aggregation( sumIntegers, f -> f.sum().field( "integer", Integer.class ) ) @@ -192,10 +187,11 @@ private SearchQuery defineAggregations( .aggregation( avgIntegers, f -> f.avg().field( "integer", Integer.class ) ) .aggregation( avgIntegersAsString, f -> f.avg().field( "integer", String.class, ValueModel.STRING ) ) .aggregation( avgConverted, f -> f.avg().field( "converted", String.class ) ) - .aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class ) ) + .aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class, ValueModel.RAW ) ) .aggregation( avgIntegersAsDoubleRaw, f -> f.avg().field( "integer", Double.class, ValueModel.RAW ) ) - .aggregation( avgIntegersAsDoubleFiltered, f -> f.avg().field( "object.nestedInteger", Double.class ) - .filter( ff -> ff.range().field( "object.nestedInteger" ).atLeast( 5 ) ) ) + .aggregation( avgIntegersAsDoubleFiltered, + f -> f.avg().field( "object.nestedInteger", Double.class, ValueModel.RAW ) + .filter( ff -> ff.range().field( "object.nestedInteger" ).atLeast( 5 ) ) ) .aggregation( sumDoubles, f -> f.sum().field( "doubleF", Double.class ) ) .aggregation( sumDoublesRaw, f -> f.sum().field( "doubleF", Double.class, ValueModel.RAW ) ) .aggregation( sumFloats, f -> f.sum().field( "floatF", Float.class ) ) diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java index 26a1ad9259e..de86070ee01 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java @@ -4,6 +4,7 @@ */ package org.hibernate.search.integrationtest.backend.tck.testsupport.operations; +import static org.assertj.core.api.Assertions.assertThat; import static org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMapperUtils.documentProvider; import java.util.List; @@ -12,16 +13,23 @@ import org.hibernate.search.engine.backend.common.DocumentReference; import org.hibernate.search.engine.search.aggregation.AggregationKey; +import org.hibernate.search.engine.search.common.ValueModel; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; import org.hibernate.search.engine.search.query.SearchQuery; import org.hibernate.search.engine.search.query.SearchResult; import org.hibernate.search.engine.search.query.dsl.SearchQueryOptionsStep; import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.ByteFieldTypeDescriptor; import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.IntegerFieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.LongFieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.ShortFieldTypeDescriptor; import org.hibernate.search.integrationtest.backend.tck.testsupport.types.values.MetricAggregationsValues; import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldLocation; import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldValueCardinality; import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TestedFieldStructure; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TypeAssertionHelper; +import org.hibernate.search.util.common.SearchException; import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer; import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubLoadingOptionsStep; import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMappingScope; @@ -36,7 +44,11 @@ public class MetricAggregationsTestCase { public MetricAggregationsTestCase(FieldTypeDescriptor typeDescriptor) { this.typeDescriptor = typeDescriptor; - metricAggregationsValues = typeDescriptor.metricAggregationsValues(); + this.metricAggregationsValues = typeDescriptor.metricAggregationsValues(); + } + + public FieldTypeDescriptor typeDescriptor() { + return typeDescriptor; } public int contribute(BulkIndexer indexer, SingleFieldIndexBinding binding) { @@ -47,7 +59,8 @@ public int contribute(BulkIndexer indexer, SingleFieldIndexBinding binding) { String keyB = String.format( Locale.ROOT, "%03d_NEST_%s", i, uniqueName ); String keyC = String.format( Locale.ROOT, "%03d_FLAT_%s", i, uniqueName ); indexer.add( documentProvider( keyA, document -> binding.initSingleValued( typeDescriptor, - IndexFieldLocation.ROOT, document, value ) ) ); + IndexFieldLocation.ROOT, document, value + ) ) ); indexer.add( documentProvider( keyB, document -> binding.initSingleValued( typeDescriptor, IndexFieldLocation.IN_NESTED, document, value ) ) ); indexer.add( documentProvider( keyC, document -> binding.initSingleValued( @@ -56,27 +69,29 @@ public int contribute(BulkIndexer indexer, SingleFieldIndexBinding binding) { return metricAggregationsValues.values().size() * 3; } - public Result testMetricsAggregation(StubMappingScope scope, SingleFieldIndexBinding binding) { - InternalResult result = new InternalResult<>(); - String fieldPath = binding.getFieldPath( TestedFieldStructure.of( - IndexFieldLocation.ROOT, IndexFieldValueCardinality.SINGLE_VALUED ), typeDescriptor ); - - SearchQueryOptionsStep step = scope.query().where( - SearchPredicateFactory::matchAll ) - .aggregation( result.minKey, f -> f.min().field( fieldPath, typeDescriptor.getJavaType() ) ) - .aggregation( result.maxKey, f -> f.max().field( fieldPath, typeDescriptor.getJavaType() ) ) + public Result testMetricsAggregation(StubMappingScope scope, SingleFieldIndexBinding binding, ValueModel valueModel, + TypeAssertionHelper typeAssertionHelper) { + InternalResult result = new InternalResult<>(); + String fieldPath = binding.getFieldPath( + TestedFieldStructure.of( IndexFieldLocation.ROOT, IndexFieldValueCardinality.SINGLE_VALUED ), typeDescriptor ); + Class javaClass = typeAssertionHelper.getJavaClass(); + + SearchQueryOptionsStep step = scope.query() + .where( SearchPredicateFactory::matchAll ) + .aggregation( result.minKey, f -> f.min().field( fieldPath, javaClass, valueModel ) ) + .aggregation( result.maxKey, f -> f.max().field( fieldPath, javaClass, valueModel ) ) .aggregation( result.countKey, f -> f.count().field( fieldPath ) ) .aggregation( result.countDistinctKey, f -> f.countDistinct().field( fieldPath ) ) - .aggregation( result.avgKey, f -> f.avg().field( fieldPath, typeDescriptor.getJavaType() ) ); + .aggregation( result.avgKey, f -> f.avg().field( fieldPath, javaClass, valueModel ) ); if ( metricAggregationsValues.sum() != null ) { - step.aggregation( result.sumKey, f -> f.sum().field( fieldPath, typeDescriptor.getJavaType() ) ); + step.aggregation( result.sumKey, f -> f.sum().field( fieldPath, javaClass, valueModel ) ); } - SearchQuery query = step - .toQuery(); - result.apply( query, metricAggregationsValues ); - return new Result<>( typeDescriptor.getJavaType(), metricAggregationsValues, result ); + SearchQuery query = step.toQuery(); + result.apply( query ); + + return new Result<>( result, typeAssertionHelper, valueModel ); } @Override @@ -84,75 +99,66 @@ public String toString() { return "Case{" + typeDescriptor + '}'; } - public static class Result { - private final Class javaType; - private final MetricAggregationsValues metricAggregationsValues; - private final InternalResult result; + public class Result { + private final InternalResult result; + private final TypeAssertionHelper typeAssertionHelper; + private final ValueModel valueModel; - private Result(Class javaType, MetricAggregationsValues metricAggregationsValues, - InternalResult result) { - this.javaType = javaType; - this.metricAggregationsValues = metricAggregationsValues; + private Result(InternalResult result, TypeAssertionHelper typeAssertionHelper, ValueModel valueModel) { this.result = result; + this.typeAssertionHelper = typeAssertionHelper; + this.valueModel = valueModel; } public List values() { return metricAggregationsValues.values(); } - // expected* can return null, which would mean that this type does not support a particular aggregation - public F expectedSum() { - return metricAggregationsValues.sum(); - } - - public F expectedMin() { - return metricAggregationsValues.min(); - } - - public F expectedMax() { - return metricAggregationsValues.max(); - } - - public Long expectedCount() { - return metricAggregationsValues.count(); - } - - public Long expectedCountDistinct() { - return metricAggregationsValues.countDistinct(); - } - - public F expectedAvg() { - return metricAggregationsValues.avg(); - } - - public F computedSum() { - return result.sum; - } - - public F computedMax() { - return result.max; - } - - public F computedMin() { - return result.min; - } - - public Long computedCount() { - return result.count; + @SuppressWarnings("unchecked") + public void validate() { + validateCommon(); + if ( metricAggregationsValues.sum() != null ) { + typeAssertionHelper.assertSameAggregation( result.sum, metricAggregationsValues.sum() ); + } + typeAssertionHelper.assertSameAggregation( result.min, metricAggregationsValues.min() ); + typeAssertionHelper.assertSameAggregation( result.max, metricAggregationsValues.max() ); + + // Elasticsearch would return a double average even for int types, and if we access a raw value, + // it can contain decimals, so we handle this case differently to the others: + if ( typeAssertionHelper.getJavaClass().equals( String.class ) + && ValueModel.RAW.equals( valueModel ) + && ( IntegerFieldTypeDescriptor.INSTANCE.equals( typeDescriptor ) + || LongFieldTypeDescriptor.INSTANCE.equals( typeDescriptor ) + || ShortFieldTypeDescriptor.INSTANCE.equals( typeDescriptor ) + || ByteFieldTypeDescriptor.INSTANCE.equals( typeDescriptor ) ) ) { + // the cast is "safe" as we've tested the `getJavaClass` just above. + typeAssertionHelper.assertSameAggregation( result.avg, + (F) Double.toString( metricAggregationsValues.avgRaw() ) ); + } + else { + typeAssertionHelper.assertSameAggregation( result.avg, metricAggregationsValues.avg() ); + } } - public Long computedCountDistinct() { - return result.countDistinct; + public void validateDouble() { + validateCommon(); + if ( metricAggregationsValues.sum() != null ) { + assertThat( ( (Number) result.sum ).doubleValue() ).isEqualTo( metricAggregationsValues.sumRaw() ); + } + assertThat( ( (Number) result.min ).doubleValue() ).isEqualTo( metricAggregationsValues.minRaw() ); + assertThat( ( (Number) result.max ).doubleValue() ).isEqualTo( metricAggregationsValues.maxRaw() ); + assertThat( ( (Number) result.avg ).doubleValue() ).isEqualTo( metricAggregationsValues.avgRaw() ); } - public F computedAvg() { - return result.avg; + private void validateCommon() { + assertThat( result.count ).isEqualTo( metricAggregationsValues.count() ); + assertThat( result.countDistinct ).isEqualTo( metricAggregationsValues.countDistinct() ); } @Override public String toString() { return new StringJoiner( ", ", Result.class.getSimpleName() + "[", "]" ) - .add( "javaType=" + javaType ) + .add( "javaType=" + "javaType" ) .add( "values=" + values() ) .toString(); } @@ -173,11 +179,13 @@ private static class InternalResult { Long countDistinct; F avg; - void apply(SearchQuery query, MetricAggregationsValues metricAggregationsValues) { + void apply(SearchQuery query) { SearchResult result = query.fetch( 0 ); - if ( metricAggregationsValues.sum() != null ) { + try { sum = result.aggregation( sumKey ); } + catch (SearchException e) { + } min = result.aggregation( minKey ); max = result.aggregation( maxKey ); count = result.aggregation( countKey ); diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/BigIntegerFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/BigIntegerFieldTypeDescriptor.java index 23c9207c0aa..9236242105f 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/BigIntegerFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/BigIntegerFieldTypeDescriptor.java @@ -82,6 +82,11 @@ public BigInteger avg() { } return BigInteger.valueOf( 550L ); } + + @Override + protected double doubleValueOf(double value) { + return value * 100L; + } }; } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/InstantFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/InstantFieldTypeDescriptor.java index 1adf003e3ce..8e1a92bfe5d 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/InstantFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/InstantFieldTypeDescriptor.java @@ -75,6 +75,11 @@ protected Instant valueOf(int value) { public Instant avg() { return Instant.parse( "1970-01-07T08:46:40Z" ); } + + @Override + protected double doubleValueOf(double value) { + return value * 100_000_000; + } }; } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/LocalDateTimeFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/LocalDateTimeFieldTypeDescriptor.java index 785ce48bc93..31bb8707287 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/LocalDateTimeFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/LocalDateTimeFieldTypeDescriptor.java @@ -72,6 +72,11 @@ protected LocalDateTime valueOf(int value) { public LocalDateTime avg() { return LocalDateTime.parse( "1970-01-01T00:00:05.500" ); } + + @Override + protected double doubleValueOf(double value) { + return value * 1_000; + } }; } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/OffsetDateTimeFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/OffsetDateTimeFieldTypeDescriptor.java index aed5ab851ee..4f7f5c73ecb 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/OffsetDateTimeFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/OffsetDateTimeFieldTypeDescriptor.java @@ -8,6 +8,7 @@ import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; @@ -82,6 +83,11 @@ protected OffsetDateTime valueOf(int value) { public OffsetDateTime avg() { return OffsetDateTime.parse( "1970-01-07T08:46:40Z" ); } + + @Override + protected double doubleValueOf(double value) { + return value * 100_000_000; + } }; } @@ -160,4 +166,9 @@ public Optional> getIndex LocalDateTime.of( 2018, 3, 1, 12, 14, 52 ).atOffset( ZoneOffset.ofHours( 1 ) ) ) ); } + + @Override + public String format(OffsetDateTime value) { + return DateTimeFormatter.ISO_OFFSET_DATE_TIME.format( value ); + } } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/YearFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/YearFieldTypeDescriptor.java index 7a439255c8a..a5f17b823cb 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/YearFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/YearFieldTypeDescriptor.java @@ -71,6 +71,11 @@ public Year sum() { protected Year valueOf(int value) { return Year.of( value + 2000 ); } + + @Override + protected double doubleValueOf(double value) { + return value + 2000; + } }; } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/ZonedDateTimeFieldTypeDescriptor.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/ZonedDateTimeFieldTypeDescriptor.java index d644ec114b0..2c5f74349c8 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/ZonedDateTimeFieldTypeDescriptor.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/ZonedDateTimeFieldTypeDescriptor.java @@ -9,6 +9,7 @@ import java.time.ZoneId; import java.time.ZoneOffset; import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; @@ -91,6 +92,11 @@ protected ZonedDateTime valueOf(int value) { public ZonedDateTime avg() { return ZonedDateTime.parse( "1970-01-01T00:00:05.500Z" ); } + + @Override + protected double doubleValueOf(double value) { + return value * 1_000; + } }; } @@ -182,4 +188,9 @@ public Optional> getIndexN LocalDateTime.of( 2018, 3, 1, 12, 14, 52 ).atZone( ZoneId.of( "Europe/Paris" ) ) ) ); } + + @Override + public String format(ZonedDateTime value) { + return DateTimeFormatter.ISO_ZONED_DATE_TIME.format( value ); + } } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/values/MetricAggregationsValues.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/values/MetricAggregationsValues.java index e5b804347f1..0c5f4094339 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/values/MetricAggregationsValues.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/types/values/MetricAggregationsValues.java @@ -42,6 +42,26 @@ public F avg() { return valueOf( 5 ); } + public double avgRaw() { + return doubleValueOf( 5.5 ); + } + + public double minRaw() { + return doubleValueOf( -10 ); + } + + public double maxRaw() { + return doubleValueOf( 18 ); + } + + public double sumRaw() { + return doubleValueOf( 55 ); + } + + protected double doubleValueOf(double value) { + return value; + } + protected abstract F valueOf(int value); private List createValues() { diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TckBackendFeatures.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TckBackendFeatures.java index e27597c10d8..05bbeb8537e 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TckBackendFeatures.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TckBackendFeatures.java @@ -4,6 +4,7 @@ */ package org.hibernate.search.integrationtest.backend.tck.testsupport.util; +import java.time.temporal.TemporalAccessor; import java.util.Comparator; import java.util.Optional; @@ -173,7 +174,7 @@ public boolean vectorSearchRequiredMinimumSimilarityAsLucene() { public abstract Object toRawValue(FieldTypeDescriptor descriptor, F value); - // with some backends sorts need require a different raw value to what other places like predicates will allow. + // with some backends sorts require a different raw value to what other places like predicates will allow. // E.g. Elasticsearch won't accept the formatted string date-time types and expects it to be in form of a number instead. public Object toSortRawValue(FieldTypeDescriptor descriptor, F value) { return toRawValue( descriptor, value ); @@ -200,4 +201,24 @@ public boolean rangeAggregationsDoNotIgnoreQuery() { public boolean negativeDecimalScaleIsAppliedToAvgAggregationFunction() { return true; } + + public T fromRawAggregation(FieldTypeDescriptor typeDescriptor, T value) { + return value; + } + + public Double toDoubleValue(FieldTypeDescriptor descriptor, F fieldValue) { + if ( Number.class.isAssignableFrom( descriptor.getJavaType() ) ) { + return ( (Number) fieldValue ).doubleValue(); + } + + if ( TemporalAccessor.class.isAssignableFrom( descriptor.getJavaType() ) ) { + return ( (Number) toSortRawValue( descriptor, fieldValue ) ).doubleValue(); + } + + throw new UnsupportedOperationException( "Type " + descriptor.getJavaType() + " is not supported" ); + } + + public boolean rawAggregationProduceSensibleDoubleValue(FieldTypeDescriptor fFieldTypeDescriptor) { + return true; + } } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TypeAssertionHelper.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TypeAssertionHelper.java index ed8044e3ab7..819f03fa494 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TypeAssertionHelper.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TypeAssertionHelper.java @@ -4,6 +4,8 @@ */ package org.hibernate.search.integrationtest.backend.tck.testsupport.util; +import static org.assertj.core.api.Assertions.assertThat; + import java.math.BigDecimal; import java.math.RoundingMode; import java.time.Instant; @@ -29,6 +31,10 @@ private TypeAssertionHelper() { public abstract T create(F fieldValue); + public void assertSameAggregation(T value1, F value2) { + assertThat( value1 ).isEqualTo( create( value2 ) ); + } + public boolean isSame(F a, F b) { return Objects.equals( a, b ); } @@ -143,6 +149,19 @@ public T create(F fieldValue) { public boolean isSame(F a, F b) { return isSame.test( a, b ); } + + @Override + public void assertSameAggregation(T value1, F value2) { + if ( Number.class.isAssignableFrom( typeDescriptor.getJavaType() ) ) { + assertThat( Double.parseDouble( TckConfiguration.get().getBackendFeatures() + .fromRawAggregation( typeDescriptor, value1 ).toString() ) ) + .isEqualTo( Double.parseDouble( value2.toString() ) ); + } + else { + assertThat( TckConfiguration.get().getBackendFeatures().fromRawAggregation( typeDescriptor, value1 ) ) + .isEqualTo( create( value2 ) ); + } + } }; } @@ -160,6 +179,20 @@ public String create(F fieldValue) { }; } + public static TypeAssertionHelper rawDouble(FieldTypeDescriptor typeDescriptor) { + return new TypeAssertionHelper() { + @Override + public Class getJavaClass() { + return Double.class; + } + + @Override + public Double create(F fieldValue) { + return TckConfiguration.get().getBackendFeatures().toDoubleValue( typeDescriptor, fieldValue ); + } + }; + } + private static R neverCalled(P1 param) { throw new IllegalStateException( "This should not be called; called with parameter " + param ); }