diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala
index e71183980d7..48927ef3fb4 100644
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala
@@ -17,14 +17,15 @@
 package com._4paradigm.openmldb.batch.nodes
 
 import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException
-import com._4paradigm.hybridse.node.{ConstNode, ExprType, DataType => HybridseDataType}
+import com._4paradigm.hybridse.node.{CastExprNode, ConstNode, ExprNode, ExprType, DataType => HybridseDataType}
 import com._4paradigm.hybridse.vm.PhysicalConstProjectNode
 import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance}
 import com._4paradigm.openmldb.batch.utils.{DataTypeUtil, ExpressionUtil}
-import org.apache.spark.sql.{Column, SparkSession}
+import org.apache.spark.sql.{Column, Row, SparkSession}
 import org.apache.spark.sql.functions.{to_date, when}
 import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType,
-  StringType, TimestampType}
+  StringType, StructField, StructType, TimestampType}
+
 import java.sql.Timestamp
 import scala.collection.JavaConverters.asScalaBufferConverter
 
@@ -44,29 +45,43 @@ object ConstProjectPlan {
     // Get the select columns
     val selectColList = (0 until node.project().size.toInt).map(i => {
       val expr = node.project().GetExpr(i)
-      expr.GetExprType() match {
-        case ExprType.kExprPrimary =>
-          val constNode = ConstNode.CastFrom(expr)
-          val outputColName = outputColNameList(i)
-
-          // Create simple literal Spark column
-          val column = ExpressionUtil.constExprToSparkColumn(constNode)
-
-          // Match column type for output type
-          castSparkOutputCol(ctx.getSparkSession, column, constNode.GetDataType(), outputColTypeList(i))
-            .alias(outputColName)
+      val (column, innerType) = createSparkColumn(ctx.getSparkSession, expr)
 
-        case _ => throw new UnsupportedHybridSeException(
-          s"Should not handle non-const column for const project node")
-      }
+      // Match column type for output type
+      castSparkOutputCol(ctx.getSparkSession, column, innerType, outputColTypeList(i))
+        .alias(outputColNameList(i))
     })
 
     // Use Spark DataFrame to select columns
-    val result = ctx.getSparkSession.emptyDataFrame.select(selectColList: _*)
+    val sparkSession = ctx.getSparkSession
+    val originDf = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(Seq(Row("1"))),
+      StructType(List(StructField("1", StringType))))
+    val result = originDf.select(selectColList: _*)
 
     SparkInstance.createConsideringIndex(ctx, node.GetNodeId(), result)
   }
 
+  def createSparkColumn(spark: SparkSession,
+                        expr: ExprNode): (Column, HybridseDataType) = {
+    expr.GetExprType() match {
+      case ExprType.kExprPrimary =>
+        val constNode = ConstNode.CastFrom(expr)
+
+        // Create simple literal Spark column
+        ExpressionUtil.constExprToSparkColumn(constNode) -> constNode.GetDataType
+
+      case ExprType.kExprCast =>
+        val cast = CastExprNode.CastFrom(expr)
+        val castType = cast.getCast_type_
+        val (childCol, childType) = createSparkColumn(spark, cast.GetChild(0))
+        val castColumn = castSparkOutputCol(spark, childCol, childType, castType)
+        castColumn -> castType
+
+      case _ => throw new UnsupportedHybridSeException(
+        s"Should not handle non-const column for const project node")
+    }
+  }
+
   def stringToTimestamp: String => Timestamp = (input: String) => {
     if (input == null) {
       null.asInstanceOf[Timestamp]
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala
index 67fb9a7070a..a50bb6ca56c 100644
--- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala
@@ -16,10 +16,13 @@
 
 package com._4paradigm.openmldb.batch
 
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.types.{DateType, DoubleType, IntegerType, StringType, StructField, StructType}
 
+import java.sql.Date
 import scala.collection.JavaConverters.seqAsJavaListConverter
+import scala.collection.mutable
 
@@ -47,4 +50,26 @@
     val output = res.getDf()
     output.show()
   }
+
+  test("Test const project") {
+    val sess = getSparkSession
+
+    val schema = StructType(Seq(
+      StructField("1", IntegerType, nullable = false),
+      StructField("3.500000", DoubleType, nullable = false),
+      StructField("a", StringType, nullable = false),
+      StructField("date(2024-03-25)", DateType, nullable = true),
+      StructField("string(int32(int64(1)))", StringType, nullable = false)
+    ))
+    val expectDf = sess.createDataFrame(Seq(
+      (1, 3.5d, "a", Date.valueOf("2024-03-25"), "1")
+    ).map(Row.fromTuple(_)).asJava, schema)
+
+    val planner = new SparkPlanner(sess)
+    val res = planner.plan("select 1, 3.5, \"a\", date('2024-03-25'), string(int(bigint(1)));",
+      mutable.HashMap[String, mutable.Map[String, DataFrame]]())
+    val output = res.getDf()
+
+    assert(SparkUtil.approximateDfEqual(expectDf, output))
+  }
 }
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestSelectInto.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestSelectInto.scala
index f6e2785f1a5..55a78a81975 100644
--- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestSelectInto.scala
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestSelectInto.scala
@@ -34,9 +34,9 @@ class TestSelectInto extends SparkTestSuite {
     sess.sql(sqlText)
 
     val csvDf = spark.read.option("header", true).csv(csvFilePath)
 
-    assert(csvDf.isEmpty)
     assert(csvDf.schema.size == 1)
     assert(csvDf.schema.fields(0).name.equals("1"))
+    assert(csvDf.count() == 1)
   }
 }
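
Note on the emptyDataFrame change (and on the flipped assertions in TestSelectInto): a projection of
literal columns over a zero-row DataFrame yields zero rows, which is why a const query such as
"SELECT 1;" used to write an empty CSV. The patch instead selects the literals from a one-row
placeholder DataFrame, so exactly one row comes out. The following is a minimal self-contained
sketch of both behaviors, assuming a local Spark session; the object name OneRowDemo is
illustrative and not part of this patch:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object OneRowDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("one-row-demo").getOrCreate()

    // A literal projection over an empty DataFrame produces no rows,
    // which is why the old const project plan emitted empty output.
    val fromEmpty = spark.emptyDataFrame.select(lit(1).alias("1"))
    assert(fromEmpty.count() == 0)

    // The patch selects the same literals from a one-row placeholder
    // DataFrame instead, so exactly one row is produced.
    val oneRow = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row("1"))),
      StructType(List(StructField("1", StringType))))
    assert(oneRow.select(lit(1).alias("1")).count() == 1)

    spark.stop()
  }
}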
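
Note on createSparkColumn: it evaluates const expressions recursively, so a nested cast such as
string(int(bigint(1))) reduces bottom-up: the literal child becomes a Spark Column first, then each
enclosing cast is applied on top of it. Below is a minimal sketch of that recursion using a toy
expression ADT in place of HybridSE's ExprNode/CastExprNode (Expr, Lit, Cast, and toColumn are
illustrative names, not part of this patch), and Spark's built-in Column.cast where the real code
calls its own castSparkOutputCol:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}

object CastRecursionSketch {
  // Toy stand-ins for HybridSE's expression node hierarchy.
  sealed trait Expr
  case class Lit(value: Any, tpe: DataType) extends Expr
  case class Cast(child: Expr, to: DataType) extends Expr

  // Mirrors createSparkColumn: return the column together with its inner
  // type, so the caller knows whether a further output cast is needed.
  def toColumn(expr: Expr): (Column, DataType) = expr match {
    case Lit(v, t) => lit(v) -> t
    case Cast(child, to) =>
      val (childCol, _) = toColumn(child) // evaluate the child first
      childCol.cast(to) -> to             // then apply this level's cast
  }

  def main(args: Array[String]): Unit = {
    // string(int(bigint(1))): casts nest outward from the literal 1.
    val expr = Cast(Cast(Cast(Lit(1, IntegerType), LongType), IntegerType), StringType)
    val (col, tpe) = toColumn(expr)
    println(s"column: $col, type: $tpe") // a Column expression of StringType
  }
}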