From ae859d4cacb9d86f1ccf05eba57931bb9dc75697 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 22 Nov 2023 18:39:21 +0800 Subject: [PATCH 1/2] feat(offline): WINDOW without ORDER BY only work if SkewWindowOpt is off --- cases/query/window_query.yaml | 29 ++++ hybridse/include/vm/mem_catalog.h | 125 ++---------------- hybridse/src/vm/core_api.cc | 6 +- hybridse/src/vm/core_api.h | 9 +- hybridse/src/vm/mem_catalog.cc | 118 +++++++++++++++++ hybridse/src/vm/runner.cc | 1 + hybridse/src/vm/runner.h | 2 + .../openmldb/batch/nodes/WindowAggPlan.scala | 4 + .../batch/utils/PhysicalNodeUtil.scala | 13 +- .../openmldb/batch/window/RowDebugger.scala | 6 +- .../batch/window/WindowAggPlanUtil.scala | 14 +- .../batch/window/WindowComputer.scala | 19 ++- .../openmldb/batch/end2end/TestWindow.scala | 38 ++++++ 13 files changed, 252 insertions(+), 132 deletions(-) diff --git a/cases/query/window_query.yaml b/cases/query/window_query.yaml index 3c64259d8c5..bf28ecb863a 100644 --- a/cases/query/window_query.yaml +++ b/cases/query/window_query.yaml @@ -904,6 +904,9 @@ cases: # ====================================================================== # WINDOW without ORDER BY + # + # batch mode tests skipped since ordering in WINDOW is undefined, we only + # verify result for request mode, that's implementation-defined order, not SQL standard # ====================================================================== - id: 24 desc: ROWS WINDOW WITHOUT ORDER BY @@ -1132,3 +1135,29 @@ cases: 3, 1, 0, 3, 3 4, 2, 1, 3, 3 5, 3, 2, 3, 3 + - id: 28 + # simple case verify it compile & run for batch mode + desc: RANGE WINDOW WITHOUT ORDER BY + inputs: + - name: t1 + columns: + - id int + - gp int + - ts timestamp + indexs: + - idx:gp:ts + data: | + 1, 100, 20000 + 2, 100, 10000 + 3, 400, 20000 + 4, 400, 10 + 5, 400, 15000 + sql: | + select id, count(ts) over w as agg + from t1 + window w as ( + partition by gp + rows_range between unbounded preceding and current row + ) + expect: + success: true 
diff --git a/hybridse/include/vm/mem_catalog.h b/hybridse/include/vm/mem_catalog.h index 6237edd1d43..172cd68db26 100644 --- a/hybridse/include/vm/mem_catalog.h +++ b/hybridse/include/vm/mem_catalog.h @@ -258,9 +258,13 @@ class Window : public MemTimeTableHandler { bool exclude_current_time() const { return exclude_current_time_; } void set_exclude_current_time(bool flag) { exclude_current_time_ = flag; } + bool without_order_by() const { return without_order_by_; } + void set_without_order_by(bool flag) { without_order_by_ = flag; } + protected: bool exclude_current_time_ = false; bool instance_not_in_window_ = false; + bool without_order_by_ = false; }; class WindowRange { public: @@ -356,44 +360,13 @@ class HistoryWindow : public Window { PopFrontRow(); } } + bool BufferData(uint64_t key, const Row& row) override; - // aad newer row into window - bool BufferData(uint64_t key, const Row& row) override { - if (!table_.empty() && GetFrontRow().first > key) { - DLOG(WARNING) << "Fail BufferData: buffer key less than latest key"; - return false; - } - auto cur_size = table_.size(); - if (cur_size < window_range_.start_row_) { - // current in the ROWS window - int64_t sub = key + window_range_.start_offset_; - uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); - if (0 == window_range_.end_offset_) { - return BufferCurrentTimeBuffer(key, row, start_ts); - } else { - return BufferEffectiveWindow(key, row, start_ts); - } - } else if (0 == window_range_.end_offset_) { - // current in the ROWS_RANGE window - int64_t sub = (static_cast(key) + window_range_.start_offset_); - uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); - return BufferCurrentTimeBuffer(key, row, start_ts); - } else { - // current row BeforeWindow - int64_t sub = (key + window_range_.end_offset_); - uint64_t end_ts = sub < 0 ? 
0u : static_cast(sub); - return BufferCurrentHistoryBuffer(key, row, end_ts); - } - } + // add newer row into window + bool BufferDataImpl(uint64_t key, const Row& row); protected: - bool BufferCurrentHistoryBuffer(uint64_t key, const Row& row, uint64_t end_ts) { - current_history_buffer_.emplace_front(key, row); - int64_t sub = (static_cast(key) + window_range_.start_offset_); - uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); - SlideWindow(start_ts, end_ts); - return true; - } + bool BufferCurrentHistoryBuffer(uint64_t key, const Row& row, uint64_t end_ts); // sliding rows data from `current_history_buffer_` into effective window // by giving the new start_ts and end_ts. @@ -413,77 +386,18 @@ class HistoryWindow : public Window { // `start_ts_inclusive` and `end_ts_inclusive` can be empty, which effectively means less than 0. // if `start_ts_inclusive` is empty, no rows goes out of effective window // if `end_ts_inclusive` is empty, no rows goes out of history buffer and into effective window - void SlideWindow(std::optional start_ts_inclusive, std::optional end_ts_inclusive) { - // always try to cleanup the stale rows out of effective window - if (start_ts_inclusive.has_value()) { - Slide(start_ts_inclusive); - } - - if (!end_ts_inclusive.has_value()) { - return; - } - - while (!current_history_buffer_.empty() && current_history_buffer_.back().first <= end_ts_inclusive) { - auto& back = current_history_buffer_.back(); - - BufferEffectiveWindow(back.first, back.second, start_ts_inclusive); - current_history_buffer_.pop_back(); - } - } + void SlideWindow(std::optional start_ts_inclusive, std::optional end_ts_inclusive); // push the row to the start of window // - pop last elements in window if exceed max window size // - also pop last elements in window if there ts less than `start_ts` // // if `start_ts` is empty, no rows eliminated from window - bool BufferEffectiveWindow(uint64_t key, const Row& row, std::optional start_ts) { - AddFrontRow(key, row); - 
return Slide(start_ts); - } + bool BufferEffectiveWindow(uint64_t key, const Row& row, std::optional start_ts); - bool Slide(std::optional start_ts) { - auto cur_size = table_.size(); - while (window_range_.max_size_ > 0 && - cur_size > window_range_.max_size_) { - PopBackRow(); - --cur_size; - } + bool Slide(std::optional start_ts); - // Slide window if window start bound >= rows/range preceding - while (cur_size > 0) { - const auto& pair = GetBackRow(); - if ((kFrameRows == window_range_.frame_type_ || kFrameRowsMergeRowsRange == window_range_.frame_type_) && - cur_size <= window_range_.start_row_ + 1) { - // note it is always current rows window - break; - } - if (kFrameRows == window_range_.frame_type_ || pair.first < start_ts) { - PopBackRow(); - --cur_size; - } else { - break; - } - } - return true; - } - - bool BufferCurrentTimeBuffer(uint64_t key, const Row& row, uint64_t start_ts) { - if (exclude_current_time_) { - // except `exclude current_row`, the current row is always added to the effective window - // but for next buffer action, previous current row already buffered in `current_history_buffer_` - // so the previous current row need eliminated for this next buf action - PopEffectiveDataIfAny(); - if (key == 0) { - SlideWindow(start_ts, {}); - } else { - SlideWindow(start_ts, key - 1); - } - current_history_buffer_.emplace_front(key, row); - } - - // in queue the current row - return BufferEffectiveWindow(key, row, start_ts); - } + bool BufferCurrentTimeBuffer(uint64_t key, const Row& row, uint64_t start_ts); WindowRange window_range_; MemTimeTable current_history_buffer_; @@ -512,20 +426,7 @@ class CurrentHistoryWindow : public HistoryWindow { void PopFrontData() override { PopFrontRow(); } - bool BufferData(uint64_t key, const Row& row) override { - if (!table_.empty() && GetFrontRow().first > key) { - DLOG(WARNING) << "Fail BufferData: buffer key less than latest key"; - return false; - } - int64_t sub = (key + window_range_.start_offset_); - 
uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); - - if (exclude_current_time_) { - return BufferCurrentTimeBuffer(key, row, start_ts); - } else { - return BufferEffectiveWindow(key, row, start_ts); - } - } + bool BufferData(uint64_t key, const Row& row) override; }; typedef std::map 0 && end_offset == 0) { max_size++; } @@ -38,6 +35,7 @@ WindowInterface::WindowInterface(bool instance_not_in_window, bool exclude_curre WindowRange(ExtractFrameType(frame_type_str), start_offset, end_offset, rows_preceding, max_size)); window_impl_->set_instance_not_in_window(instance_not_in_window); window_impl_->set_exclude_current_time(exclude_current_time); + window_impl_->set_without_order_by(without_order_by); } bool WindowInterface::BufferData(uint64_t key, const Row& row) { diff --git a/hybridse/src/vm/core_api.h b/hybridse/src/vm/core_api.h index 08b3ea3664e..c236abbb1cb 100644 --- a/hybridse/src/vm/core_api.h +++ b/hybridse/src/vm/core_api.h @@ -17,7 +17,6 @@ #ifndef HYBRIDSE_SRC_VM_CORE_API_H_ #define HYBRIDSE_SRC_VM_CORE_API_H_ -#include #include #include #include "codec/fe_row_codec.h" @@ -41,7 +40,13 @@ class WindowInterface { public: WindowInterface(bool instance_not_in_window, bool exclude_current_time, bool execlude_current_row, const std::string& frame_type_str, int64_t start_offset, int64_t end_offset, - uint64_t rows_preceding, uint64_t max_size); + uint64_t rows_preceding, uint64_t max_size) + // delegate to the full constructor: a ctor call inside the body only creates a discarded temporary + : WindowInterface(instance_not_in_window, exclude_current_time, execlude_current_row, frame_type_str, + start_offset, end_offset, rows_preceding, max_size, false) {} + WindowInterface(bool instance_not_in_window, bool exclude_current_time, bool execlude_current_row, + const std::string& frame_type_str, int64_t start_offset, int64_t end_offset, + uint64_t rows_preceding, uint64_t max_size, bool without_order_by); bool BufferData(uint64_t key, const Row& row); diff --git a/hybridse/src/vm/mem_catalog.cc b/hybridse/src/vm/mem_catalog.cc index f4f5897f10f..2786df61ea8
100644 --- a/hybridse/src/vm/mem_catalog.cc +++ b/hybridse/src/vm/mem_catalog.cc @@ -414,5 +414,123 @@ size_t RowGetSliceSize(int8_t* row_ptr, size_t idx) { auto row = reinterpret_cast(row_ptr); return row->size(idx); } +bool HistoryWindow::BufferData(uint64_t key, const Row& row) { + if (without_order_by()) { + return BufferDataImpl(0, row); + } + + return BufferDataImpl(key, row); +} +bool HistoryWindow::BufferDataImpl(uint64_t key, const Row& row) { + if (!table_.empty() && GetFrontRow().first > key) { + DLOG(WARNING) << "Fail BufferData: buffer key (" << key << ") less than latest key (" << GetFrontRow().first + << ")"; + return false; + } + auto cur_size = table_.size(); + if (cur_size < window_range_.start_row_) { + // current in the ROWS window + int64_t sub = key + window_range_.start_offset_; + uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); + if (0 == window_range_.end_offset_) { + return BufferCurrentTimeBuffer(key, row, start_ts); + } else { + return BufferEffectiveWindow(key, row, start_ts); + } + } else if (0 == window_range_.end_offset_) { + // current in the ROWS_RANGE window + int64_t sub = (static_cast(key) + window_range_.start_offset_); + uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); + return BufferCurrentTimeBuffer(key, row, start_ts); + } else { + // current row BeforeWindow + int64_t sub = (key + window_range_.end_offset_); + uint64_t end_ts = sub < 0 ? 0u : static_cast(sub); + return BufferCurrentHistoryBuffer(key, row, end_ts); + } +} +bool HistoryWindow::BufferCurrentHistoryBuffer(uint64_t key, const Row& row, uint64_t end_ts) { + current_history_buffer_.emplace_front(key, row); + int64_t sub = (static_cast(key) + window_range_.start_offset_); + uint64_t start_ts = sub < 0 ? 
0u : static_cast(sub); + SlideWindow(start_ts, end_ts); + return true; +} +void HistoryWindow::SlideWindow(std::optional start_ts_inclusive, std::optional end_ts_inclusive) { + // always try to cleanup the stale rows out of effective window + if (start_ts_inclusive.has_value()) { + Slide(start_ts_inclusive); + } + + if (!end_ts_inclusive.has_value()) { + return; + } + + while (!current_history_buffer_.empty() && current_history_buffer_.back().first <= end_ts_inclusive) { + auto& back = current_history_buffer_.back(); + + BufferEffectiveWindow(back.first, back.second, start_ts_inclusive); + current_history_buffer_.pop_back(); + } +} +bool HistoryWindow::BufferEffectiveWindow(uint64_t key, const Row& row, std::optional start_ts) { + AddFrontRow(key, row); + return Slide(start_ts); +} +bool HistoryWindow::Slide(std::optional start_ts) { + auto cur_size = table_.size(); + while (window_range_.max_size_ > 0 && cur_size > window_range_.max_size_) { + PopBackRow(); + --cur_size; + } + + // Slide window if window start bound >= rows/range preceding + while (cur_size > 0) { + const auto& pair = GetBackRow(); + if ((kFrameRows == window_range_.frame_type_ || kFrameRowsMergeRowsRange == window_range_.frame_type_) && + cur_size <= window_range_.start_row_ + 1) { + // note it is always current rows window + break; + } + if (kFrameRows == window_range_.frame_type_ || pair.first < start_ts) { + PopBackRow(); + --cur_size; + } else { + break; + } + } + return true; +} +bool HistoryWindow::BufferCurrentTimeBuffer(uint64_t key, const Row& row, uint64_t start_ts) { + if (exclude_current_time_) { + // except `exclude current_row`, the current row is always added to the effective window + // but for next buffer action, previous current row already buffered in `current_history_buffer_` + // so the previous current row need eliminated for this next buf action + PopEffectiveDataIfAny(); + if (key == 0) { + SlideWindow(start_ts, {}); + } else { + SlideWindow(start_ts, key - 1); + } + 
current_history_buffer_.emplace_front(key, row); + } + + // in queue the current row + return BufferEffectiveWindow(key, row, start_ts); +} +bool CurrentHistoryWindow::BufferData(uint64_t key, const Row& row) { + if (!table_.empty() && GetFrontRow().first > key) { + DLOG(WARNING) << "Fail BufferData: buffer key less than latest key"; + return false; + } + int64_t sub = (key + window_range_.start_offset_); + uint64_t start_ts = sub < 0 ? 0u : static_cast(sub); + + if (exclude_current_time_) { + return BufferCurrentTimeBuffer(key, row, start_ts); + } else { + return BufferEffectiveWindow(key, row, start_ts); + } +} } // namespace vm } // namespace hybridse diff --git a/hybridse/src/vm/runner.cc b/hybridse/src/vm/runner.cc index eb284e6e945..5721f9bb34f 100644 --- a/hybridse/src/vm/runner.cc +++ b/hybridse/src/vm/runner.cc @@ -651,6 +651,7 @@ void WindowAggRunner::RunWindowAggOnKey( HistoryWindow window(instance_window_gen_.range_gen_->window_range_); window.set_instance_not_in_window(instance_not_in_window_); window.set_exclude_current_time(exclude_current_time_); + window.set_without_order_by(without_order_by()); while (instance_segment_iter->Valid()) { if (limit_cnt_.has_value() && cnt >= limit_cnt_) { diff --git a/hybridse/src/vm/runner.h b/hybridse/src/vm/runner.h index b40130db812..5a9f29a0c90 100644 --- a/hybridse/src/vm/runner.h +++ b/hybridse/src/vm/runner.h @@ -553,6 +553,8 @@ class WindowAggRunner : public Runner { const bool instance_not_in_window_; const bool exclude_current_time_; + bool without_order_by() const { return !instance_window_gen_.sort_gen_.Valid(); } + // slice size outputed of the first producer node const size_t append_slices_; WindowGenerator instance_window_gen_; diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala index 824efbddfc2..8523f446e72 100755 --- 
a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala @@ -215,6 +215,10 @@ object WindowAggPlan { val repartitionColIndexes = PhysicalNodeUtil.getRepartitionColumnIndexes(windowAggNode, inputDf) val orderByColIndex = PhysicalNodeUtil.getOrderbyColumnIndex(windowAggNode, inputDf) + if (orderByColIndex < 0) { + throw new Exception("WindowSkewOpt can not run for WINDOW without ORDER BY") + } + // Register the input table val partIdColName = "PART_ID" + uniqueNamePostfix val expandedRowColName = "EXPANDED_ROW" + uniqueNamePostfix diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/PhysicalNodeUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/PhysicalNodeUtil.scala index 186699dfa59..62b45248f17 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/PhysicalNodeUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/PhysicalNodeUtil.scala @@ -86,9 +86,12 @@ object PhysicalNodeUtil { : mutable.ArrayBuffer[Column] = { val windowOp = windowAggNode.window() + val orderByCols = mutable.ArrayBuffer[Column]() val orders = windowOp.sort().orders() + if (orders == null) { + return orderByCols + } val ordersExprListNode = orders.getOrder_expressions_ - val orderByCols = mutable.ArrayBuffer[Column]() for (i <- 0 until ordersExprListNode.GetChildNum()) { val orderExpr = orders.GetOrderExpression(i) @@ -130,6 +133,10 @@ object PhysicalNodeUtil { val windowOp = windowAggNode.window() val orders = windowOp.sort().orders() + if (orders == null) { + // WINDOW without ORDER BY + return ""; + } val orderExprListNode = orders.getOrder_expressions_ if (orderExprListNode.GetChildNum() <= 0) { @@ -150,6 +157,10 @@ object PhysicalNodeUtil { val windowOp = windowAggNode.window() val orders = windowOp.sort().orders() + if (orders == 
null) { + // WINDOW without ORDER BY + return -1; + } val orderExprListNode = orders.getOrder_expressions_ if (orderExprListNode.GetChildNum() <= 0) { diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/RowDebugger.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/RowDebugger.scala index 3dd6907734d..579afb81565 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/RowDebugger.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/RowDebugger.scala @@ -40,8 +40,10 @@ class RowDebugger(sqlConfig: OpenmldbBatchConfig, config: WindowAggConfig, isSke def printRow(computer: WindowComputer, row: Row): Unit = { if (cnt % sampleInterval == 0) { val str = new StringBuffer() - str.append(row.get(config.orderIdx)) - str.append(",") + if (config.orderIdx >= 0) { + str.append(row.get(config.orderIdx)) + str.append(",") + } for (e <- config.groupIdxs) { str.append(row.get(e)) str.append(",") diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowAggPlanUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowAggPlanUtil.scala index 9886a3212a0..0d2f206b733 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowAggPlanUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowAggPlanUtil.scala @@ -172,11 +172,17 @@ object WindowAggPlanUtil { // process order key val orders = windowOp.sort().orders() - val ordersExprListNode = orders.getOrder_expressions_() - if (ordersExprListNode.GetChildNum() > 1) { - throw new HybridSeException("Multiple window order not supported") + if (orders != null) { + val ordersExprListNode = orders.getOrder_expressions_() + if (ordersExprListNode.GetChildNum() > 1) { + throw new HybridSeException("Multiple window order not supported") + } + } + val orderIdx = if (orders == null) { + -1 + } 
else { + SparkColumnUtil.resolveOrderColumnIndex(orders.GetOrderExpression(0), node.GetProducer(0)) } - val orderIdx = SparkColumnUtil.resolveOrderColumnIndex(orders.GetOrderExpression(0), node.GetProducer(0)) // process group-by keys val groups = windowOp.partition().keys() diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowComputer.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowComputer.scala index 9336b25340d..c124892f1a5 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowComputer.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowComputer.scala @@ -47,9 +47,6 @@ class WindowComputer(config: WindowAggConfig, jit: HybridSeJitWrapper, keepIndex protected var encoder = new SparkRowCodec(config.inputSchemaSlices) private var decoder = new SparkRowCodec(config.outputSchemaSlices) - // order key field - private val orderField = config.inputSchema(config.orderIdx) - // append slices cnt = needAppendInput ? 
inputSchemaSlices.size : 0 private val appendSlices = if (config.needAppendInput) config.inputSchemaSlices.length else 0 @@ -73,7 +70,7 @@ class WindowComputer(config: WindowAggConfig, jit: HybridSeJitWrapper, keepIndex config.excludeCurrentTime, config.excludeCurrentRow, config.windowFrameTypeName, - config.startOffset, config.endOffset, config.rowPreceding, config.maxSize) + config.startOffset, config.endOffset, config.rowPreceding, config.maxSize, config.orderIdx < 0) def compute(row: Row, key: Long, keepIndexColumn: Boolean, unionFlagIdx: Int, inputSchemaSize: Int, outputSchema: StructType, enableUnsafeRowFormat: Boolean): Row = { @@ -250,16 +247,24 @@ class WindowComputer(config: WindowAggConfig, jit: HybridSeJitWrapper, keepIndex window = new WindowInterface( config.instanceNotInWindow, config.excludeCurrentTime, config.excludeCurrentRow, config.windowFrameTypeName, - config.startOffset, config.endOffset, config.rowPreceding, config.maxSize) + config.startOffset, config.endOffset, config.rowPreceding, config.maxSize, config.orderIdx < 0) } def extractKey(curRow: Row): Long = { - SparkRowUtil.getLongFromIndex(config.orderIdx, orderField.dataType, curRow) + if (config.orderIdx < 0) { + // no ORDER BY: all to 0 + return 0 + } + SparkRowUtil.getLongFromIndex(config.orderIdx, config.inputSchema(config.orderIdx).dataType, curRow) } def extractUnsafeKey(curRow: UnsafeRow): Long = { + if (config.orderIdx < 0) { + // no ORDER BY: all to 0 + return 0 + } // TODO(tobe): support different data types - SparkRowUtil.unsafeGetLongFromIndex(config.orderIdx, orderField.dataType, curRow) + SparkRowUtil.unsafeGetLongFromIndex(config.orderIdx, config.inputSchema(config.orderIdx).dataType, curRow) } def delete(): Unit = { diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestWindow.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestWindow.scala index 9d7b9b09cfa..af73e4cea02 100644 --- 
a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestWindow.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestWindow.scala @@ -122,4 +122,42 @@ class TestWindow extends SparkTestSuite { assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), compareDf, false)) } + test("Test end2end WINDOW without ORDER BY") { + + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + // test with small set, ordering is undetermined for WINDOW without ORDER BY + val data = Seq( + Row(1, "tom", 100, 1), + Row(2, "amy", 200, 2), + Row(3, "tom", 300, 3), + Row(4, "tom", 400, 4)) + + val schema = StructType(List( + StructField("id", IntegerType), + StructField("user", StringType), + StructField("trans_amount", IntegerType), + StructField("trans_time", IntegerType))) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(data), schema) + + sess.registerTable("t1", df) + df.createOrReplaceTempView("t1") + + val sqlText = + """ + | SELECT id,sum(trans_amount) OVER w AS w_sum_amount FROM t1 + | WINDOW w AS ( + | PARTITION BY user + | ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW); + """.stripMargin + + val outputDf = sess.sql(sqlText) + + val sparksqlOutputDf = sess.sparksql(sqlText) + outputDf.show() + sparksqlOutputDf.show() + // Notice that the sum column type is different for SparkSQL and SparkFE + assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) + } } From 6f7dac35850f4c6fc5b8e3fe307e95ad03f1ea47 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Fri, 24 Nov 2023 06:27:33 +0000 Subject: [PATCH 2/2] feat(offline): support last join (lastjoin/window) --- .../batch/nulldata/TestJoinWithNullData.scala | 45 +++++++++++++++++++ .../nulldata/TestWindowWithNullData.scala | 44 ++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestJoinWithNullData.scala 
b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestJoinWithNullData.scala index 7f43f57975c..fbc831437de 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestJoinWithNullData.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestJoinWithNullData.scala @@ -93,4 +93,49 @@ class TestJoinWithNullData extends SparkTestSuite { assert(outputDf.collect().size == data.size) } + test("Test last join to last join") { + + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + val data1 = Seq(Row(1, 1), Row(2, 2)) + val schema1 = StructType(List( + StructField("id", IntegerType), + StructField("val", IntegerType))) + val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema1) + + val data2 = Seq(Row(1, 3), Row(2, 4)) + val schema2 = StructType(List( + StructField("id", IntegerType), + StructField("val", IntegerType))) + val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema2) + + val data3 = Seq(Row(1, 5)) + val schema3 = StructType(List( + StructField("id3", IntegerType), + StructField("val3", IntegerType))) + val df3 = spark.createDataFrame(spark.sparkContext.makeRDD(data3), schema3) + + sess.registerTable("t1", df1) + sess.registerTable("t2", df2) + sess.registerTable("t3", df3) + + val sqlText = + """select t1.id, tx.id as id2, tx.val3 from t1 last join + | (select * from t2 last join t3 on t2.id = t3.id3) tx on t1.id = tx.id + """.stripMargin + + val outputDf = sess.sql(sqlText) + outputDf.show() + + val output = Seq(Row(1, 1, 5), Row(2, 2, null)) + val output_sc = StructType(List( + StructField("id", IntegerType), + StructField("id2", IntegerType), + StructField("val3", IntegerType))) + val expect_df = spark.createDataFrame(spark.sparkContext.makeRDD(output), output_sc) + expect_df.show() + + assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), expect_df, true)) + } } diff --git 
a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestWindowWithNullData.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestWindowWithNullData.scala index 1528f893e17..19511219327 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestWindowWithNullData.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/nulldata/TestWindowWithNullData.scala @@ -65,4 +65,48 @@ class TestWindowWithNullData extends SparkTestSuite { assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) } + test("Test last join to window") { + + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + val data1 = Seq(Row(1, 1), Row(2, 2)) + val schema1 = StructType(List( + StructField("id", IntegerType), + StructField("val", IntegerType))) + val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema1) + + val data2 = Seq(Row(1, 3, 1), Row(1, 3, 2), Row(2, 4, 3)) + val schema2 = StructType(List( + StructField("id", IntegerType), + StructField("key", IntegerType), + StructField("val", IntegerType))) + val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema2) + + sess.registerTable("t1", df1) + sess.registerTable("t2", df2) + + val sqlText = + """select t1.id, tx.id as id2, tx.agg from t1 last join + | ( + | select id, val, sum(val) over w as agg from t2 + | window w as (partition by key order by val rows between 3 preceding and current row) + | ) tx order by tx.val + | on t1.id = tx.id + """.stripMargin + + val outputDf = sess.sql(sqlText) + outputDf.show() + + val output = Seq(Row(1, 1, 3), Row(2, 2, 3)) + val output_sc = StructType(List( + StructField("id", IntegerType), + StructField("id2", IntegerType), + StructField("agg", IntegerType))) + val expect_df = spark.createDataFrame(spark.sparkContext.makeRDD(output), output_sc) + expect_df.show() + + 
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), expect_df, true)) + } + }