Skip to content

Commit

Permalink
HSEARCH-5133 Support avg field-typed
Browse files Browse the repository at this point in the history
  • Loading branch information
fax4ever committed Jun 27, 2024
1 parent 1d8ac44 commit 28aebae
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 116 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.hibernate.search.backend.elasticsearch.search.common.impl.AbstractElasticsearchCodecAwareSearchQueryElementFactory;
import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope;
import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexValueFieldContext;
import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchDoubleFieldCodec;
import org.hibernate.search.backend.elasticsearch.types.codec.impl.ElasticsearchFieldCodec;
import org.hibernate.search.engine.backend.types.converter.runtime.FromDocumentValueConvertContext;
import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter;
Expand Down Expand Up @@ -85,18 +86,31 @@ private TypeSelector(ElasticsearchFieldCodec<F> codec,

@Override
public <T> Builder<F, T> type(Class<T> expectedType, ValueConvert convert) {
ProjectionConverter<F, ? extends T> projectionConverter = null;
if ( !Double.class.isAssignableFrom( expectedType ) ||
field.type().projectionConverter( convert ).valueType().isAssignableFrom( expectedType ) ) {
projectionConverter = field.type().projectionConverter( convert )
.withConvertedType( expectedType, field );
}
return new Builder<>( codec, scope, field,
field.type().projectionConverter( convert ).withConvertedType( expectedType, field ),
operation );
projectionConverter,
operation
);
}
}

@SuppressWarnings("unchecked")
private class MetricFieldExtractor implements Extractor<K> {
@Override
public K extract(JsonObject aggregationResult, AggregationExtractContext context) {
FromDocumentValueConvertContext convertContext = context.fromDocumentValueConvertContext();
JsonElement value = aggregationResult.get( "value" );
JsonElement valueAsString = aggregationResult.get( "value_as_string" );

if ( fromFieldValueConverter == null ) {
Double decode = ElasticsearchDoubleFieldCodec.INSTANCE.decode( value );
return (K) decode;
}
return fromFieldValueConverter.fromDocumentValue(
codec.decodeAggregationValue( value, valueAsString ),
convertContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package org.hibernate.search.backend.elasticsearch.types.dsl.impl;

import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricDoubleAggregation;
import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricFieldAggregation;
import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchMetricLongAggregation;
import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchRangeAggregation;
Expand Down Expand Up @@ -77,7 +76,7 @@ protected final void complete() {
builder.queryElementFactory( AggregationTypeKeys.COUNT_DISTINCT,
new ElasticsearchMetricLongAggregation.Factory<>( codec, "cardinality" ) );
builder.queryElementFactory( AggregationTypeKeys.AVG,
new ElasticsearchMetricDoubleAggregation.Factory<>( codec, "avg" )
new ElasticsearchMetricFieldAggregation.Factory<>( codec, "avg" )
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,25 @@ public interface AvgAggregationFieldStep<PDF extends SearchPredicateFactory> {
* Target the given field in the avg aggregation.
*
* @param fieldPath The <a href="SearchAggregationFactory.html#field-paths">path</a> to the index field to aggregate.
* @param type The type of field values.
* @param <F> The type of field values or {@link Double} if a double result is required.
* @return The next step.
*/
default <F> AvgAggregationOptionsStep<?, PDF, F> field(String fieldPath, Class<F> type) {
return field( fieldPath, type, ValueConvert.YES );
}

/**
* Target the given field in the avg aggregation.
*
* @param fieldPath The <a href="SearchAggregationFactory.html#field-paths">path</a> to the index field to aggregate.
* @param type The type of field values.
* @param <F> The type of field values or {@link Double} if a double result is required.
* @param convert Controls how the ranges passed to the next steps and fetched from the backend should be converted.
* See {@link ValueConvert}.
* @return The next step.
*/
AvgAggregationOptionsStep<?, PDF> field(String fieldPath);
<F> AvgAggregationOptionsStep<?, PDF, F> field(String fieldPath, Class<F> type,
ValueConvert convert);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
*
* @param <S> The "self" type (the actual exposed type of this step).
* @param <PDF> The type of factory used to create predicates in {@link #filter(Function)}.
* @param <F> The type of the targeted field. The type of result for this aggregation.
*/
public interface AvgAggregationOptionsStep<
S extends AvgAggregationOptionsStep<?, PDF>,
PDF extends SearchPredicateFactory>
extends AggregationFinalStep<Double>, AggregationFilterStep<S, PDF> {
S extends AvgAggregationOptionsStep<?, PDF, F>,
PDF extends SearchPredicateFactory,
F>
extends AggregationFinalStep<F>, AggregationFilterStep<S, PDF> {

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
import org.hibernate.search.engine.search.aggregation.dsl.AvgAggregationOptionsStep;
import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext;
import org.hibernate.search.engine.search.aggregation.spi.AggregationTypeKeys;
import org.hibernate.search.engine.search.aggregation.spi.SearchFilterableAggregationBuilder;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.engine.search.common.ValueConvert;
import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory;
import org.hibernate.search.util.common.impl.Contracts;

public class AvgAggregationFieldStepImpl<PDF extends SearchPredicateFactory>
implements AvgAggregationFieldStep<PDF> {
public class AvgAggregationFieldStepImpl<PDF extends SearchPredicateFactory> implements AvgAggregationFieldStep<PDF> {
private final SearchAggregationDslContext<?, ? extends PDF> dslContext;

public AvgAggregationFieldStepImpl(SearchAggregationDslContext<?, ? extends PDF> dslContext) {
this.dslContext = dslContext;
}

@Override
public AvgAggregationOptionsStep<?, PDF> field(String fieldPath) {
SearchFilterableAggregationBuilder<Double> builder = dslContext.scope()
.fieldQueryElement( fieldPath, AggregationTypeKeys.AVG );
public <F> AvgAggregationOptionsStep<?, PDF, F> field(String fieldPath, Class<F> type,
ValueConvert convert) {
Contracts.assertNotNull( fieldPath, "fieldPath" );
Contracts.assertNotNull( type, "type" );
FieldMetricAggregationBuilder<F> builder = dslContext.scope()
.fieldQueryElement( fieldPath, AggregationTypeKeys.AVG ).type( type, convert );
return new AvgAggregationOptionsStepImpl<>( builder, dslContext );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,38 @@
import org.hibernate.search.engine.search.aggregation.SearchAggregation;
import org.hibernate.search.engine.search.aggregation.dsl.AvgAggregationOptionsStep;
import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext;
import org.hibernate.search.engine.search.aggregation.spi.SearchFilterableAggregationBuilder;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.engine.search.predicate.SearchPredicate;
import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep;
import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory;

class AvgAggregationOptionsStepImpl<PDF extends SearchPredicateFactory>
implements AvgAggregationOptionsStep<AvgAggregationOptionsStepImpl<PDF>, PDF> {
private final SearchFilterableAggregationBuilder<Double> builder;
class AvgAggregationOptionsStepImpl<PDF extends SearchPredicateFactory, F>
implements AvgAggregationOptionsStep<AvgAggregationOptionsStepImpl<PDF, F>, PDF, F> {
private final FieldMetricAggregationBuilder<F> builder;
private final SearchAggregationDslContext<?, ? extends PDF> dslContext;

AvgAggregationOptionsStepImpl(SearchFilterableAggregationBuilder<Double> builder,
AvgAggregationOptionsStepImpl(FieldMetricAggregationBuilder<F> builder,
SearchAggregationDslContext<?, ? extends PDF> dslContext) {
this.builder = builder;
this.dslContext = dslContext;
}

@Override
public AvgAggregationOptionsStepImpl<PDF> filter(
public AvgAggregationOptionsStepImpl<PDF, F> filter(
Function<? super PDF, ? extends PredicateFinalStep> clauseContributor) {
SearchPredicate predicate = clauseContributor.apply( dslContext.predicateFactory() ).toPredicate();

return filter( predicate );
}

@Override
public AvgAggregationOptionsStepImpl<PDF> filter(SearchPredicate searchPredicate) {
public AvgAggregationOptionsStepImpl<PDF, F> filter(SearchPredicate searchPredicate) {
builder.filter( searchPredicate );
return this;
}

@Override
public SearchAggregation<Double> toAggregation() {
public SearchAggregation<F> toAggregation() {
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private AggregationTypeKeys() {
of( IndexFieldTraits.Aggregations.COUNT );
public static final SearchQueryElementTypeKey<SearchFilterableAggregationBuilder<Long>> COUNT_DISTINCT =
of( IndexFieldTraits.Aggregations.COUNT_DISTINCT );
public static final SearchQueryElementTypeKey<SearchFilterableAggregationBuilder<Double>> AVG =
public static final SearchQueryElementTypeKey<FieldMetricAggregationBuilder.TypeSelector> AVG =
of( IndexFieldTraits.Aggregations.AVG );

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ public class ElasticsearchMetricAggregationsIT {
private final AggregationKey<Long> countConverted = AggregationKey.of( "countConverted" );
private final AggregationKey<Long> countDistinctIntegers = AggregationKey.of( "countDistinctIntegers" );
private final AggregationKey<Long> countDistinctConverted = AggregationKey.of( "countDistinctConverted" );
private final AggregationKey<Double> avgIntegers = AggregationKey.of( "avgIntegers" );
private final AggregationKey<Double> avgConverted = AggregationKey.of( "avgConverted" );
private final AggregationKey<Integer> avgIntegers = AggregationKey.of( "avgIntegers" );
private final AggregationKey<String> avgConverted = AggregationKey.of( "avgConverted" );
private final AggregationKey<Double> avgIntegersAsDouble = AggregationKey.of( "avgIntegersAsDouble" );

@BeforeEach
void setup() {
Expand All @@ -62,8 +63,9 @@ public void test_filteringResults() {
.aggregation( countConverted, f -> f.count().field( "converted" ) )
.aggregation( countDistinctIntegers, f -> f.countDistinct().field( "integer" ) )
.aggregation( countDistinctConverted, f -> f.countDistinct().field( "converted" ) )
.aggregation( avgIntegers, f -> f.avg().field( "integer" ) )
.aggregation( avgConverted, f -> f.avg().field( "converted" ) )
.aggregation( avgIntegers, f -> f.avg().field( "integer", Integer.class ) )
.aggregation( avgConverted, f -> f.avg().field( "converted", String.class ) )
.aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class ) )
.toQuery();

SearchResult<DocumentReference> result = query.fetch( 0 );
Expand All @@ -77,8 +79,9 @@ public void test_filteringResults() {
assertThat( result.aggregation( countConverted ) ).isEqualTo( 5 );
assertThat( result.aggregation( countDistinctIntegers ) ).isEqualTo( 3 );
assertThat( result.aggregation( countDistinctConverted ) ).isEqualTo( 3 );
assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5.8 );
assertThat( result.aggregation( avgConverted ) ).isEqualTo( 5.8 );
assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5 );
assertThat( result.aggregation( avgConverted ) ).isEqualTo( "5" );
assertThat( result.aggregation( avgIntegersAsDouble ) ).isEqualTo( 5.8 );
}

@Test
Expand All @@ -96,8 +99,9 @@ public void test_allResults() {
.aggregation( countConverted, f -> f.count().field( "converted" ) )
.aggregation( countDistinctIntegers, f -> f.countDistinct().field( "integer" ) )
.aggregation( countDistinctConverted, f -> f.countDistinct().field( "converted" ) )
.aggregation( avgIntegers, f -> f.avg().field( "integer" ) )
.aggregation( avgConverted, f -> f.avg().field( "converted" ) )
.aggregation( avgIntegers, f -> f.avg().field( "integer", Integer.class ) )
.aggregation( avgConverted, f -> f.avg().field( "converted", String.class ) )
.aggregation( avgIntegersAsDouble, f -> f.avg().field( "integer", Double.class ) )
.toQuery();

SearchResult<DocumentReference> result = query.fetch( 0 );
Expand All @@ -111,8 +115,9 @@ public void test_allResults() {
assertThat( result.aggregation( countConverted ) ).isEqualTo( 10 );
assertThat( result.aggregation( countDistinctIntegers ) ).isEqualTo( 6 );
assertThat( result.aggregation( countDistinctConverted ) ).isEqualTo( 6 );
assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5.5 );
assertThat( result.aggregation( avgConverted ) ).isEqualTo( 5.5 );
assertThat( result.aggregation( avgIntegers ) ).isEqualTo( 5 );
assertThat( result.aggregation( avgConverted ) ).isEqualTo( "5" );
assertThat( result.aggregation( avgIntegersAsDouble ) ).isEqualTo( 5.5 );
}

private void initData() {
Expand Down

0 comments on commit 28aebae

Please sign in to comment.