fix: select constant return empty value (#3825)
* fix: select const return empty value

* fix: select const return empty value, add test

* fix: support cast expression when select constant in offline mode

* fix: test case for support cast expression when select constant in offline mode

* fix: modify test case, select constant will return dataframe with 1 row

---------

Co-authored-by: yebowen <[email protected]>
Matagits authored Mar 27, 2024
1 parent 5df6d4e commit 4c6406e
Showing 3 changed files with 61 additions and 21 deletions.
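Context for the diff below: before this change, the offline (batch) engine projected constant-only SELECTs against spark.emptyDataFrame, so the result had the right columns but zero rows; the fix projects against a dummy one-row DataFrame instead, and adds cast-expression support so nested casts over constants also work. A minimal sketch of how the batch planner is driven, mirroring the new test case further down (the SparkSession `sess` and the import path are assumptions, not part of the commit):

import scala.collection.mutable
import org.apache.spark.sql.DataFrame
import com._4paradigm.openmldb.batch.SparkPlanner

// Plan a constant-only SELECT with the OpenMLDB batch planner.
// `sess` stands for a configured SparkSession (the test suite obtains one
// via getSparkSession).
val planner = new SparkPlanner(sess)
val res = planner.plan(
  "select 1, 3.5, \"a\", date('2024-03-25'), string(int(bigint(1)));",
  mutable.HashMap[String, mutable.Map[String, DataFrame]]())

// After this fix the returned DataFrame holds exactly one row of the constants;
// before, it came back with the right schema but no rows.
res.getDf().show()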
@@ -17,14 +17,15 @@
package com._4paradigm.openmldb.batch.nodes

import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException
-import com._4paradigm.hybridse.node.{ConstNode, ExprType, DataType => HybridseDataType}
+import com._4paradigm.hybridse.node.{CastExprNode, ConstNode, ExprNode, ExprType, DataType => HybridseDataType}
import com._4paradigm.hybridse.vm.PhysicalConstProjectNode
import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance}
import com._4paradigm.openmldb.batch.utils.{DataTypeUtil, ExpressionUtil}
-import org.apache.spark.sql.{Column, SparkSession}
+import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.{to_date, when}
import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType,
-  StringType, TimestampType}
+  StringType, StructField, StructType, TimestampType}

import java.sql.Timestamp
import scala.collection.JavaConverters.asScalaBufferConverter

@@ -44,29 +45,43 @@ object ConstProjectPlan {
    // Get the select columns
    val selectColList = (0 until node.project().size.toInt).map(i => {
      val expr = node.project().GetExpr(i)
-      expr.GetExprType() match {
-        case ExprType.kExprPrimary =>
-          val constNode = ConstNode.CastFrom(expr)
-          val outputColName = outputColNameList(i)
-
-          // Create simple literal Spark column
-          val column = ExpressionUtil.constExprToSparkColumn(constNode)
-
-          // Match column type for output type
-          castSparkOutputCol(ctx.getSparkSession, column, constNode.GetDataType(), outputColTypeList(i))
-            .alias(outputColName)
+      val (column, innerType) = createSparkColumn(ctx.getSparkSession, expr)

-        case _ => throw new UnsupportedHybridSeException(
-          s"Should not handle non-const column for const project node")
-      }
+      // Match column type for output type
+      castSparkOutputCol(ctx.getSparkSession, column, innerType, outputColTypeList(i))
+        .alias(outputColNameList(i))
    })

    // Use Spark DataFrame to select columns
-    val result = ctx.getSparkSession.emptyDataFrame.select(selectColList: _*)
+    val sparkSession = ctx.getSparkSession
+    val originDf = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(Seq(Row("1"))),
+      StructType(List(StructField("1", StringType))))
+    val result = originDf.select(selectColList: _*)

    SparkInstance.createConsideringIndex(ctx, node.GetNodeId(), result)
  }

+  def createSparkColumn(spark: SparkSession,
+                        expr: ExprNode): (Column, HybridseDataType) = {
+    expr.GetExprType() match {
+      case ExprType.kExprPrimary =>
+        val constNode = ConstNode.CastFrom(expr)
+
+        // Create simple literal Spark column
+        ExpressionUtil.constExprToSparkColumn(constNode) -> constNode.GetDataType
+
+      case ExprType.kExprCast =>
+        val cast = CastExprNode.CastFrom(expr)
+        val castType = cast.getCast_type_
+        val (childCol, childType) = createSparkColumn(spark, cast.GetChild(0))
+        val castColumn = castSparkOutputCol(spark, childCol, childType, castType)
+        castColumn -> castType
+
+      case _ => throw new UnsupportedHybridSeException(
+        s"Should not handle non-const column for const project node")
+    }
+  }

  def stringToTimestamp: String => Timestamp = (input: String) => {
    if (input == null) {
      null.asInstanceOf[Timestamp]
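The heart of the fix is visible above: DataFrame.select is row-wise and produces one output row per input row even when every selected column is a literal, so projecting constants from spark.emptyDataFrame yields zero rows, which is why constant SELECTs came back empty in offline mode. The patched path builds a throwaway one-row DataFrame and selects the constant columns from it, and the new createSparkColumn helper recurses into kExprCast nodes so nested casts such as string(int(bigint(1))) are folded cast-by-cast instead of being rejected. A minimal pure-Spark sketch of the one-row trick, independent of OpenMLDB (all names here are illustrative):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[1]").appName("const-select").getOrCreate()

// Selecting literals from an empty DataFrame keeps the schema but yields no rows.
val fromEmpty = spark.emptyDataFrame.select(lit(1).alias("c1"), lit(3.5).alias("c2"))
assert(fromEmpty.count() == 0)

// Selecting the same literals from a one-row dummy DataFrame yields exactly one row,
// which is what the fixed ConstProjectPlan does with its Row("1") placeholder.
val oneRow = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row("1"))),
  StructType(List(StructField("1", StringType))))
val fromOneRow = oneRow.select(lit(1).alias("c1"), lit(3.5).alias("c2"))
assert(fromOneRow.count() == 1)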
@@ -16,10 +16,13 @@

package com._4paradigm.openmldb.batch

-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.types.{DateType, DoubleType, IntegerType, StringType, StructField, StructType}

+import java.sql.Date
+import scala.collection.JavaConverters.seqAsJavaListConverter
import scala.collection.mutable



@@ -47,4 +50,26 @@ class TestProjectPlan extends SparkTestSuite {
    val output = res.getDf()
    output.show()
  }
+
+  test("Test const project") {
+    val sess = getSparkSession
+
+    val schema = StructType(Seq(
+      StructField("1", IntegerType, nullable = false),
+      StructField("3.500000", DoubleType, nullable = false),
+      StructField("a", StringType, nullable = false),
+      StructField("date(2024-03-25)", DateType, nullable = true),
+      StructField("string(int32(int64(1)))", StringType, nullable = false)
+    ))
+    val expectDf = sess.createDataFrame(Seq(
+      (1, 3.5d, "a", Date.valueOf("2024-03-25"), "1")
+    ).map(Row.fromTuple(_)).asJava, schema)
+
+    val planner = new SparkPlanner(sess)
+    val res = planner.plan("select 1, 3.5, \"a\", date('2024-03-25'), string(int(bigint(1)));",
+      mutable.HashMap[String, mutable.Map[String, DataFrame]]())
+    val output = res.getDf()
+
+    assert(SparkUtil.approximateDfEqual(expectDf, output))
+  }
 }
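The expected DataFrame above pins down both the constant values and the generated column names (e.g. "3.500000" and "string(int32(int64(1)))", derived from the select expressions). As a hypothetical alternative check, not part of the commit, the single result row could also be verified field by field with plain Spark:

// Hypothetical field-by-field check of the same result (assumes the output
// schema matches the expected schema above).
val row = output.collect().head
assert(output.count() == 1)
assert(row.getInt(0) == 1)
assert(row.getDouble(1) == 3.5)
assert(row.getString(2) == "a")
assert(row.getDate(3) == Date.valueOf("2024-03-25"))
assert(row.getString(4) == "1")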
@@ -34,9 +34,9 @@ class TestSelectInto extends SparkTestSuite {
    sess.sql(sqlText)

    val csvDf = spark.read.option("header", true).csv(csvFilePath)
-    assert(csvDf.isEmpty)
    assert(csvDf.schema.size == 1)
    assert(csvDf.schema.fields(0).name.equals("1"))
+    assert(csvDf.count() == 1)
  }

}
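The same behavior change shows up in the export test: dumping a constant-only SELECT with INTO OUTFILE now writes an actual data row rather than a header-only CSV. Assuming the statement under test exports the single constant column named "1" (the sqlText itself is not shown in this excerpt), the resulting file would look roughly like:

1
1

i.e. a header line followed by one data row, matching the new count() == 1 assertion.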
