Skip to content

Commit

Permalink
fix: window function range offset should be long instead of int (#733)
Browse files Browse the repository at this point in the history
* fix: window function range offset should be long instead of int

* fix error

* fall back to Spark if range offset is not int or long

* uncomment tests

* rebase

* fix offset datatype

* fix data type

* address comments

* throw Err for WindowFrameUnits::Groups

* formatting
  • Loading branch information
huaxingao authored Sep 23, 2024
1 parent 7a6f47f commit 459b2b0
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 33 deletions.
84 changes: 71 additions & 13 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1692,16 +1692,46 @@ impl PhysicalPlanner {
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range => {
WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
LowerFrameBoundStruct::Preceding(offset) => {
let offset_value = offset.offset.unsigned_abs() as u64;
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
let offset_value = offset.offset.abs();
match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
Some(offset_value as u64),
)),
WindowFrameUnits::Range => {
WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
}
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
};

let upper_bound: WindowFrameBound = match spark_window_frame
Expand All @@ -1710,15 +1740,43 @@ impl PhysicalPlanner {
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
UpperFrameBoundStruct::Following(offset) => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
UpperFrameBoundStruct::Following(offset) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
WindowFrameUnits::Range => {
WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::Int64(None)),
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
Expand Down
4 changes: 2 additions & 2 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ message UpperWindowFrameBound {
}

message Preceding {
int32 offset = 1;
int64 offset = 1;
}

message Following {
int32 offset = 1;
int64 offset = 1;
}

message UnboundedPreceding {}
Expand Down
57 changes: 52 additions & 5 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
(None, exprToProto(windowExpr.windowFunction, output))
}

if (aggExpr.isEmpty && builtinFunc.isEmpty) {
return None
}

val f = windowExpr.windowSpec.frameSpecification

val (frameType, lowerBound, upperBound) = f match {
case SpecifiedWindowFrame(frameType, lBound, uBound) =>
val frameProto = frameType match {
case RowFrame => OperatorOuterClass.WindowFrameType.Rows
case RangeFrame =>
withInfo(windowExpr, "Range frame is not supported")
return None
case RangeFrame => OperatorOuterClass.WindowFrameType.Range
}

val lBoundProto = lBound match {
Expand All @@ -278,12 +280,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setPreceding(
OperatorOuterClass.Preceding
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
Expand All @@ -300,12 +307,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}

OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setFollowing(
OperatorOuterClass.Following
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
Expand Down Expand Up @@ -2774,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}

val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))

Expand Down Expand Up @@ -3280,4 +3298,33 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
true
}
}

/**
 * Checks that a window's PARTITION BY and ORDER BY specifications refer to the same columns,
 * position by position.
 *
 * Only plain attribute references are compared (by name). If either side contains any other
 * kind of expression at some position, the specs are treated as not matching, so the caller
 * falls back to Spark instead of failing or accepting a misaligned comparison.
 *
 * @param partitionSpec
 *   the PARTITION BY expressions of the window
 * @param orderSpec
 *   the ORDER BY sort orders of the window
 * @param op
 *   the plan node that receives the fallback reason via `withInfo`
 * @return
 *   true only if both specs have the same length and every pair names the same column
 */
private def validatePartitionAndSortSpecsForWindowFunc(
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    op: SparkPlan): Boolean = {
  if (partitionSpec.length != orderSpec.length) {
    withInfo(op, "Partitioning and sorting specifications do not match")
    false
  } else {
    // Compare pairwise on the original sequences. Filtering each side first would shift
    // positions whenever a non-attribute expression is present and compare the wrong columns.
    val specsMatch = partitionSpec.zip(orderSpec).forall { case (partExpr, sortOrder) =>
      (partExpr, sortOrder.child) match {
        case (p: AttributeReference, o: AttributeReference) => p.name == o.name
        // Anything other than a plain column reference cannot be validated here.
        case _ => false
      }
    }

    if (!specsMatch) {
      withInfo(op, "Partitioning and sorting specifications must be the same.")
    }
    specsMatch
  }
}
}
39 changes: 26 additions & 13 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase {
}
}

// The partition column (k) differs from the order column (v), so the result must still match
// Spark's answer for this window query.
test(
  "fall back to Spark when the partition spec and order spec are not the same for window function") {
  // Register the view that is actually created below so it is dropped when the test finishes.
  withTempView("test_agg") {
    sql("""
      |CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
      | (1, true), (1, false),
      |(2, true), (3, false), (4, true) AS test(k, v)
      |""".stripMargin)

    val df = sql("""
      |SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
      |""".stripMargin)
    checkSparkAnswer(df)
  }
}

test("Native window operator should be CometUnaryExec") {
withTempView("testData") {
sql("""
Expand All @@ -164,11 +180,11 @@ class CometExecSuite extends CometTestBase {
|(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
|AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
|""".stripMargin)
val df = sql("""
val df1 = sql("""
|SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
|FROM testData ORDER BY cate, val
|""".stripMargin)
checkSparkAnswer(df)
checkSparkAnswer(df1)
}
}

Expand All @@ -193,23 +209,21 @@ class CometExecSuite extends CometTestBase {
}
}

test("Window range frame should fall back to Spark") {
test("Window range frame with long boundary should not fail") {
val df =
Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")

checkAnswer(
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)))
checkAnswer(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))))
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)))
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))))
}

test("Unsupported window expression should fall back to Spark") {
Expand Down Expand Up @@ -1777,10 +1791,9 @@ class CometExecSuite extends CometTestBase {
aggregateFunctions.foreach { function =>
val queries = Seq(
s"SELECT $function OVER() FROM t1",
// TODO: Range frame is not supported yet.
// s"SELECT $function OVER(order by _2) FROM t1",
// s"SELECT $function OVER(order by _2 desc) FROM t1",
// s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(order by _2) FROM t1",
s"SELECT $function OVER(order by _2 desc) FROM t1",
s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(rows between 1 preceding and 1 following) FROM t1",
s"SELECT $function OVER(order by _2 rows between 1 preceding and current row) FROM t1",
s"SELECT $function OVER(order by _2 rows between current row and 1 following) FROM t1")
Expand Down

0 comments on commit 459b2b0

Please sign in to comment.