From f4a3a814c9b8d04aae5bccb427673f7d8eac9ce3 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 25 Oct 2024 13:50:31 -0700 Subject: [PATCH] Fix FilterOperator to cache next element and avoid repeated consumption on hasNext() calls Signed-off-by: Peng Huo --- .../sql/planner/physical/FilterOperator.java | 32 +++++-- .../planner/physical/FilterOperatorTest.java | 84 +++++++++++++++++++ .../resources/correctness/bugfixes/3121.txt | 1 + 3 files changed, 109 insertions(+), 8 deletions(-) create mode 100644 integ-test/src/test/resources/correctness/bugfixes/3121.txt diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java index ec61d53163..45fb2fcc98 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java @@ -28,6 +28,7 @@ public class FilterOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; @Getter private final Expression conditions; @ToString.Exclude private ExprValue next = null; + @ToString.Exclude private boolean nextPrepared = false; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -41,19 +42,34 @@ public List getChild() { @Override public boolean hasNext() { + if (!nextPrepared) { + prepareNext(); + } + return next != null; + } + + @Override + public ExprValue next() { + if (!nextPrepared) { + prepareNext(); + } + ExprValue result = next; + next = null; + nextPrepared = false; + return result; + } + + private void prepareNext() { while (input.hasNext()) { ExprValue inputValue = input.next(); ExprValue exprValue = conditions.valueOf(inputValue.bindingTuples()); - if (!(exprValue.isNull() || exprValue.isMissing()) && (exprValue.booleanValue())) { + if (!(exprValue.isNull() || exprValue.isMissing()) && exprValue.booleanValue()) { next = inputValue; - return true; + nextPrepared = true; + return; } } - return false; - } - - @Override - public ExprValue next() { - return next; + next = null; + nextPrepared = true; } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java index bfe3b323c4..ba2354b168 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java @@ -8,14 +8,24 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.filter; import com.google.common.collect.ImmutableMap; import java.util.LinkedHashMap; import java.util.List; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -26,12 +36,22 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class FilterOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; + @Mock private Expression condition; + + private FilterOperator filterOperator; + + @BeforeEach + public void setup() { + filterOperator = filter(inputPlan, condition); + } + @Test public void filter_test() { FilterOperator plan = @@ -82,4 +102,68 @@ public void missing_value_should_been_ignored() { List result = execute(plan); assertEquals(0, result.size()); } + + @Test + public void testHasNextWhenInputHasNoElements() { + when(inputPlan.hasNext()).thenReturn(false); + + assertFalse( + filterOperator.hasNext(), "hasNext() should return false when input has no elements"); + } + + @Test + public void testHasNextWithMatchingCondition() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true).thenReturn(false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertTrue(filterOperator.hasNext(), "hasNext() should return true when condition matches"); + assertEquals( + inputValue, filterOperator.next(), "next() should return the matching input value"); + } + + @Test + public void testHasNextWithNonMatchingCondition() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_FALSE); + + assertFalse( + filterOperator.hasNext(), "hasNext() should return false if no values match the condition"); + } + + @Test + public void testMultipleCallsToHasNextDoNotConsumeInput() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertTrue( + filterOperator.hasNext(), + "First hasNext() call should return true if there is a matching value"); + verify(inputPlan, times(1)).next(); + assertTrue( + filterOperator.hasNext(), + "Subsequent hasNext() calls should still return true without advancing the input"); + verify(inputPlan, times(1)).next(); + assertEquals( + inputValue, filterOperator.next(), "next() should return the matching input value"); + verify(inputPlan, times(1)).next(); + } + + @Test + public void testNextWithoutCallingHasNext() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertEquals( + inputValue, + filterOperator.next(), + "next() should return the matching input value even if hasNext() was not called"); + } } diff --git a/integ-test/src/test/resources/correctness/bugfixes/3121.txt b/integ-test/src/test/resources/correctness/bugfixes/3121.txt new file mode 100644 index 0000000000..f60f724897 --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/3121.txt @@ -0,0 +1 @@ +SELECT Origin, Dest FROM (SELECT * FROM opensearch_dashboards_sample_data_flights WHERE AvgTicketPrice > 100 GROUP BY Origin, Dest, AvgTicketPrice) AS flights WHERE AvgTicketPrice < 1000 ORDER BY AvgTicketPrice LIMIT 30