From 1744885acb39e26ee68e14ff46a5c093cec82b2a Mon Sep 17 00:00:00 2001 From: hanbj Date: Fri, 21 Feb 2025 14:25:14 +0800 Subject: [PATCH] PointInSetQuery clips segments by lower and upper --- .../apache/lucene/search/PointInSetQuery.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java index f0e0cfd6bdb8..d37dbc02091b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java @@ -62,6 +62,8 @@ public abstract class PointInSetQuery extends Query implements Accountable { final int numDims; final int bytesPerDim; final long ramBytesUsed; // cache + byte[] lowerPoint = null; + byte[] upperPoint = null; /** Iterator of encoded point values. */ // TODO: if we want to stream, maybe we should use jdk stream class? @@ -108,6 +110,8 @@ protected PointInSetQuery(String field, int numDims, int bytesPerDim, Stream pac } if (previous == null) { previous = new BytesRefBuilder(); + lowerPoint = new byte[bytesPerDim * numDims]; + System.arraycopy(current.bytes, current.offset, lowerPoint, 0, lowerPoint.length); } else { int cmp = previous.get().compareTo(current); if (cmp == 0) { @@ -122,6 +126,11 @@ protected PointInSetQuery(String field, int numDims, int bytesPerDim, Stream pac } sortedPackedPoints = builder.finish(); sortedPackedPointsHashCode = sortedPackedPoints.hashCode(); + if (previous != null) { + BytesRef max = previous.get(); + upperPoint = new byte[bytesPerDim * numDims]; + System.arraycopy(max.bytes, max.offset, upperPoint, 0, upperPoint.length); + } ramBytesUsed = BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject(field) @@ -153,6 +162,21 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return null; } + if (values.getDocCount() == 0) { + return null; + } else if (lowerPoint != null && upperPoint != null) { + ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + final byte[] fieldPackedLower = values.getMinPackedValue(); + final byte[] fieldPackedUpper = values.getMaxPackedValue(); + for (int i = 0; i < numDims; ++i) { + int offset = i * bytesPerDim; + if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0 + || comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) { + return null; + } + } + } + if (values.getNumIndexDimensions() != numDims) { throw new IllegalArgumentException( "field=\""