Skip to content

Commit

Permalink
feat(offline): support last join (lastjoin/window)
Browse files Browse the repository at this point in the history
  • Loading branch information
aceforeverd committed Nov 24, 2023
1 parent ae859d4 commit 6f7dac3
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,49 @@ class TestJoinWithNullData extends SparkTestSuite {
assert(outputDf.collect().size == data.size)
}

test("Test last join to last join") {
  // Scenario: t2 is last-joined to t3 inside a sub-query (aliased tx), and
  // that sub-query result is in turn last-joined to t1.
  val spark = getSparkSession
  val session = new OpenmldbSession(spark)

  // Builds a DataFrame in which every listed column is an IntegerType field.
  def intFrame(rows: Seq[Row], columnNames: String*) = {
    val fields = columnNames.map(name => StructField(name, IntegerType)).toList
    spark.createDataFrame(spark.sparkContext.makeRDD(rows), StructType(fields))
  }

  val leftDf = intFrame(Seq(Row(1, 1), Row(2, 2)), "id", "val")
  val middleDf = intFrame(Seq(Row(1, 3), Row(2, 4)), "id", "val")
  // t3 only contains id3 = 1, which is why val3 is expected to be null for id = 2.
  val rightDf = intFrame(Seq(Row(1, 5)), "id3", "val3")

  session.registerTable("t1", leftDf)
  session.registerTable("t2", middleDf)
  session.registerTable("t3", rightDf)

  val query =
    """select t1.id, tx.id as id2, tx.val3 from t1 last join
      | (select * from t2 last join t3 on t2.id = t3.id3) tx on t1.id = tx.id
      |""".stripMargin

  val actual = session.sql(query)
  actual.show()

  val expected = intFrame(Seq(Row(1, 1, 5), Row(2, 2, null)), "id", "id2", "val3")
  expected.show()

  assert(SparkUtil.approximateDfEqual(actual.getSparkDf(), expected, true))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,48 @@ class TestWindowWithNullData extends SparkTestSuite {
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
}

test("Test last join to window") {
  // Scenario: the right side of the last join is a sub-query (aliased tx)
  // that computes a windowed sum over t2; the join picks rows ordered by tx.val.
  val spark = getSparkSession
  val session = new OpenmldbSession(spark)

  // Builds a DataFrame in which every listed column is an IntegerType field.
  def intFrame(rows: Seq[Row], columnNames: String*) = {
    val fields = columnNames.map(name => StructField(name, IntegerType)).toList
    spark.createDataFrame(spark.sparkContext.makeRDD(rows), StructType(fields))
  }

  val leftDf = intFrame(Seq(Row(1, 1), Row(2, 2)), "id", "val")
  // Columns are (id, key, val): ids 1 and 1 share key 3, id 2 has key 4.
  val windowSourceDf = intFrame(
    Seq(Row(1, 3, 1), Row(1, 3, 2), Row(2, 4, 3)),
    "id", "key", "val")

  session.registerTable("t1", leftDf)
  session.registerTable("t2", windowSourceDf)

  val query =
    """select t1.id, tx.id as id2, tx.agg from t1 last join
      | (
      | select id, val, sum(val) over w as agg from t2
      | window w as (partition by key order by val rows between 3 preceding and current row)
      | ) tx order by tx.val
      | on t1.id = tx.id
      |""".stripMargin

  val actual = session.sql(query)
  actual.show()

  // Running sums per key: key 3 -> 1 then 1 + 2 = 3; key 4 -> 3.
  // The row with the greatest val per id is expected to win the join,
  // giving agg = 3 for both ids.
  val expected = intFrame(Seq(Row(1, 1, 3), Row(2, 2, 3)), "id", "id2", "agg")
  expected.show()

  assert(SparkUtil.approximateDfEqual(actual.getSparkDf(), expected, true))
}

}

0 comments on commit 6f7dac3

Please sign in to comment.