diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 8fd568e2b3..a018ed2f9c 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -91,7 +91,7 @@ pub struct TopHitsVecEntry { /// The document id, composed of segment local `DocId` and segment ordinal. pub id: DocAddress, /// The sort values of the document, depending on the sort criteria in the request. - pub sort: Vec, + pub sort: Vec>, } /// The top_hits metric aggregation results a list of top hits by sort criteria. diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 2c84f6c4a4..a1f6156641 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -110,16 +110,23 @@ impl TopHitsAggregation { #[derive(Clone, Serialize, Deserialize, Debug)] struct ComparableDocFeature { /// Stores any u64-mappable feature. - value: u64, + value: Option, /// Sort order for the doc feature order: Order, } impl Ord for ComparableDocFeature { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match self.order { - Order::Asc => self.value.cmp(&other.value), - Order::Desc => other.value.cmp(&self.value), + let invert = |cmp: std::cmp::Ordering| match self.order { + Order::Asc => cmp, + Order::Desc => cmp.reverse(), + }; + + match (self.value, other.value) { + (Some(self_value), Some(other_value)) => invert(self_value.cmp(&other_value)), + (Some(_), None) => std::cmp::Ordering::Greater, + (None, Some(_)) => std::cmp::Ordering::Less, + (None, None) => std::cmp::Ordering::Equal, } } } @@ -290,10 +297,9 @@ impl SegmentAggregationCollector for SegmentTopHitsCollector { .zip(self.inner_collector.req.sort.iter()) .map(|((c, _), KeyOrder(_, order))| { let order = *order; - match c.values_for_doc(doc_id).next() { - Some(value) => ComparableDocFeature { value, order }, - // TODO: confirm if this default 0-value is correct - None => ComparableDocFeature { value: 0, order }, + ComparableDocFeature { + value: c.values_for_doc(doc_id).next(), + order, } }) .collect(); @@ -323,3 +329,54 @@ impl SegmentAggregationCollector for SegmentTopHitsCollector { Ok(()) } } + +#[cfg(test)] +mod tests { + use serde_json::Value; + + use super::{ComparableDoc, ComparableDocFeature, ComparableDocFeatures, Order}; + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::bucket::tests::get_test_index_from_docs; + use crate::aggregation::tests::get_test_index_from_values; + use crate::aggregation::AggregationCollector; + use crate::query::AllQuery; + + fn invert_order(cmp_feature: ComparableDocFeature) -> ComparableDocFeature { + let ComparableDocFeature { value, order } = cmp_feature; + let order = match order { + Order::Asc => Order::Desc, + Order::Desc => Order::Asc, + }; + ComparableDocFeature { value, order } + } + + #[test] + fn test_comparable_doc_feature() -> crate::Result<()> { + let small = ComparableDocFeature { + value: Some(1), + order: Order::Asc, + }; + let big = ComparableDocFeature { + value: Some(2), + order: Order::Asc, + }; + let none = ComparableDocFeature { + value: None, + order: Order::Asc, + }; + + assert!(small < big); + assert!(none < small); + assert!(none < big); + + let small = invert_order(small); + let big = invert_order(big); + let none = invert_order(none); + + assert!(small > big); + assert!(none < small); + assert!(none < big); + + Ok(()) + }