diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java new file mode 100644 index 00000000000..942a357f472 --- /dev/null +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/MetricFieldAggregationsIT.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.search.integrationtest.backend.tck.search.aggregation; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +import org.hibernate.search.engine.backend.document.model.dsl.IndexSchemaElement; +import org.hibernate.search.engine.backend.types.Aggregable; +import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding; +import org.hibernate.search.integrationtest.backend.tck.testsupport.operations.MetricAggregationsTestCase; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.StandardFieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.extension.SearchSetupHelper; +import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer; +import org.hibernate.search.util.impl.integrationtest.mapper.stub.SimpleMappedIndex; +import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMappingScope; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class MetricFieldAggregationsIT { + + private static final Set> supportedFieldTypes = new LinkedHashSet<>(); + private static final List> testCases = new ArrayList<>(); + private static final List parameters = new ArrayList<>(); + + static { + for ( StandardFieldTypeDescriptor typeDescriptor : FieldTypeDescriptor.getAllStandard() ) { + MetricAggregationsTestCase scenario = new MetricAggregationsTestCase<>( typeDescriptor ); + if ( !scenario.supported() ) { + continue; + } + testCases.add( scenario ); + supportedFieldTypes.add( typeDescriptor ); + parameters.add( Arguments.of( scenario ) ); + } + } + + public static List params() { + return parameters; + } + + @RegisterExtension + public static final SearchSetupHelper setupHelper = SearchSetupHelper.create(); + + private static final Function bindingFactory = + root -> SingleFieldIndexBinding.create( root, supportedFieldTypes, c -> c.aggregable( Aggregable.YES ) ); + private static final SimpleMappedIndex mainIndex = + SimpleMappedIndex.of( bindingFactory ).name( "main" ); + + @BeforeAll + static void setup() { + int expectedDocuments = 0; + + setupHelper.start().withIndexes( mainIndex ).setup(); + BulkIndexer indexer = mainIndex.bulkIndexer(); + for ( MetricAggregationsTestCase scenario : testCases ) { + expectedDocuments += scenario.contribute( indexer, mainIndex.binding() ); + } + indexer.join(); + + long createdDocuments = mainIndex.createScope().query().where( f -> f.matchAll() ) + .totalHitCountThreshold( expectedDocuments ) + .toQuery().fetch( 0 ).total().hitCountLowerBound(); + assertThat( createdDocuments ).isEqualTo( expectedDocuments ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("params") + public void test(MetricAggregationsTestCase testCase) { + StubMappingScope scope = mainIndex.createScope(); + MetricAggregationsTestCase.Result result = testCase.testMetricsAggregation( scope, mainIndex.binding() ); + if ( result.expectedSum() != null ) { + assertThat( result.computedSum() ).isEqualTo( result.expectedSum() ); + } + assertThat( result.computedMin() ).isEqualTo( result.expectedMin() ); + assertThat( result.computedMax() ).isEqualTo( result.expectedMax() ); + assertThat( result.computedCount() ).isEqualTo( result.expectedCount() ); + assertThat( result.computedCountDistinct() ).isEqualTo( result.expectedCountDistinct() ); + assertThat( result.computedAvg() ).isEqualTo( result.expectedAvg() ); + } +} diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java new file mode 100644 index 00000000000..ffa093b27ae --- /dev/null +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/operations/MetricAggregationsTestCase.java @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.search.integrationtest.backend.tck.testsupport.operations; + +import static org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMapperUtils.documentProvider; + +import java.util.List; +import java.util.Locale; +import java.util.StringJoiner; + +import org.hibernate.search.engine.backend.common.DocumentReference; +import org.hibernate.search.engine.search.aggregation.AggregationKey; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; +import org.hibernate.search.engine.search.query.SearchQuery; +import org.hibernate.search.engine.search.query.SearchResult; +import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor; +import org.hibernate.search.integrationtest.backend.tck.testsupport.types.values.MetricAggregationsValues; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldLocation; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldValueCardinality; +import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TestedFieldStructure; +import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer; +import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMappingScope; + +/** + * Denotes a metric aggregations test case for a particular {@link FieldTypeDescriptor}. + */ +public class MetricAggregationsTestCase { + + private final FieldTypeDescriptor typeDescriptor; + private final boolean supported; + private final MetricAggregationsValues metricAggregationsValues; + + public MetricAggregationsTestCase(FieldTypeDescriptor typeDescriptor) { + this.typeDescriptor = typeDescriptor; + metricAggregationsValues = typeDescriptor.metricAggregationsValues(); + this.supported = metricAggregationsValues != null; + } + + public boolean supported() { + return supported; + } + + public int contribute(BulkIndexer indexer, SingleFieldIndexBinding binding) { + int i = 0; + for ( F value : metricAggregationsValues.values() ) { + String uniqueName = typeDescriptor.getUniqueName(); + String keyA = String.format( Locale.ROOT, "%03d_ROOT_%s", ++i, uniqueName ); + String keyB = String.format( Locale.ROOT, "%03d_NEST_%s", i, uniqueName ); + String keyC = String.format( Locale.ROOT, "%03d_FLAT_%s", i, uniqueName ); + indexer.add( documentProvider( keyA, document -> binding.initSingleValued( typeDescriptor, + IndexFieldLocation.ROOT, document, value ) ) ); + indexer.add( documentProvider( keyB, document -> binding.initSingleValued( + typeDescriptor, IndexFieldLocation.IN_NESTED, document, value ) ) ); + indexer.add( documentProvider( keyC, document -> binding.initSingleValued( + typeDescriptor, IndexFieldLocation.IN_FLATTENED, document, value ) ) ); + } + return metricAggregationsValues.values().size() * 3; + } + + public Result testMetricsAggregation(StubMappingScope scope, SingleFieldIndexBinding binding) { + InternalResult result = new InternalResult<>(); + String fieldPath = binding.getFieldPath( TestedFieldStructure.of( + IndexFieldLocation.ROOT, IndexFieldValueCardinality.SINGLE_VALUED ), typeDescriptor ); + + SearchQuery query = scope.query().where( SearchPredicateFactory::matchAll ) + .aggregation( result.sumKey, f -> f.sum().field( fieldPath, typeDescriptor.getJavaType() ) ) + .aggregation( result.minKey, f -> f.min().field( fieldPath, typeDescriptor.getJavaType() ) ) + .aggregation( result.maxKey, f -> f.max().field( fieldPath, typeDescriptor.getJavaType() ) ) + .aggregation( result.countKey, f -> f.count().field( fieldPath ) ) + .aggregation( result.countDistinctKey, f -> f.countDistinct().field( fieldPath ) ) + .aggregation( result.avgKey, f -> f.avg().field( fieldPath, typeDescriptor.getJavaType() ) ) + .toQuery(); + result.apply( query ); + return new Result<>( typeDescriptor.getJavaType(), metricAggregationsValues, result ); + } + + @Override + public String toString() { + return "Case{" + typeDescriptor + '}'; + } + + public static class Result { + private final Class javaType; + private final MetricAggregationsValues metricAggregationsValues; + private final InternalResult result; + + public Result(Class javaType, MetricAggregationsValues metricAggregationsValues, + InternalResult result) { + this.javaType = javaType; + this.metricAggregationsValues = metricAggregationsValues; + this.result = result; + } + + public List values() { + return metricAggregationsValues.values(); + } + + public F expectedSum() { + return metricAggregationsValues.sum(); + } + + public F expectedMin() { + return metricAggregationsValues.min(); + } + + public F expectedMax() { + return metricAggregationsValues.max(); + } + + public Long expectedCount() { + return metricAggregationsValues.count(); + } + + public Long expectedCountDistinct() { + return metricAggregationsValues.countDistinct(); + } + + public F expectedAvg() { + return metricAggregationsValues.avg(); + } + + public F computedSum() { + return result.sum; + } + + public F computedMax() { + return result.max; + } + + public F computedMin() { + return result.min; + } + + public Long computedCount() { + return result.count; + } + + public Long computedCountDistinct() { + return result.countDistinct; + } + + public F computedAvg() { + return result.avg; + } + + @Override + public String toString() { + return new StringJoiner( ", ", Result.class.getSimpleName() + "[", "]" ) + .add( "javaType=" + javaType ) + .add( "values=" + values() ) + .toString(); + } + } + + private static class InternalResult { + AggregationKey sumKey = AggregationKey.of( "sum" ); + AggregationKey minKey = AggregationKey.of( "min" ); + AggregationKey maxKey = AggregationKey.of( "max" ); + AggregationKey countKey = AggregationKey.of( "count" ); + AggregationKey countDistinctKey = AggregationKey.of( "countDistinct" ); + AggregationKey avgKey = AggregationKey.of( "avg" ); + + F sum; + F min; + F max; + Long count; + Long countDistinct; + F avg; + + void apply(SearchQuery query) { + SearchResult result = query.fetch( 0 ); + sum = result.aggregation( sumKey ); + min = result.aggregation( minKey ); + max = result.aggregation( maxKey ); + count = result.aggregation( countKey ); + countDistinct = result.aggregation( countDistinctKey ); + avg = result.aggregation( avgKey ); + } + } +} diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TestedFieldStructure.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TestedFieldStructure.java index e7fbed1a510..cffcdfc5c8e 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TestedFieldStructure.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/testsupport/util/TestedFieldStructure.java @@ -29,7 +29,12 @@ public static List all() { return ALL; } + public static TestedFieldStructure of(IndexFieldLocation location, IndexFieldValueCardinality cardinality) { + return new TestedFieldStructure( location, cardinality ); + } + private static final List ALL; + static { List values = new ArrayList<>(); for ( IndexFieldLocation location : IndexFieldLocation.values() ) {