Support optional filtering of null values from RAW (#986)
* Support optional filtering of null values from RAW

This change exposes the RAW API's support for filtering null values server
side through the CDP data source. It does not affect schema inference, which
is still performed without the filter (see the usage sketch below).

* sbt +{,macroSub/,structType/}dependencyLockWrite

* Rename exposed option for enabling RAW API null filtering

* Add new option to README.md for RAW

* Add tests for the current null-value behavior, with filtering both explicitly disabled and left unset

* testcompile

* testcompile

---------

Co-authored-by: Dmitry Ivankov <[email protected]>
thorkildcognite and dmivankov authored Nov 14, 2024
1 parent b67d7ca commit 4e12657
Showing 11 changed files with 128 additions and 24 deletions.
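
For context, here is a minimal usage sketch of the new option. It assumes an existing SparkSession and the usual authentication options (omitted here); the database and table names are placeholders, while the option name and its `false` default are taken from the diff below.

```scala
// Hypothetical example, not part of this commit: read a RAW table with
// server-side null filtering enabled. Authentication options are omitted,
// and "my-database" / "my-table" are placeholder names.
val df = spark.read
  .format("cognite.spark.v1")
  .option("type", "raw")
  .option("database", "my-database")
  .option("table", "my-table")
  .option("inferSchema", "true")
  .option("filterNullFieldsOnNonSchemaRawQueries", "true") // defaults to false
  .load()
```

With `inferSchema` enabled, the filter only affects the row contents returned from RAW, not the inferred columns, which is what the new tests below verify.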
2 changes: 2 additions & 0 deletions README.md
@@ -1474,6 +1474,8 @@ Optionally, you can have Spark infer the DataFrame schema with the following opt

- `rawEnsureParent`: When set to true, the parent database and table will be created if they do not already exist.

- `filterNullFieldsOnNonSchemaRawQueries`: Set this to `"true"` to enable experimental support for filtering out null-valued fields (empty columns) server side in the Raw API, without impacting the inferred schema. This option is expected to become enabled by default once it has been fully tested.

```scala
val df = spark.read.format("cognite.spark.v1")
.option("type", "raw")
4 changes: 2 additions & 2 deletions build.sbt
@@ -11,7 +11,7 @@ val circeVersion = "0.14.9"
val sttpVersion = "3.5.2"
val natchezVersion = "0.3.1"
val Specs2Version = "4.20.3"
val cogniteSdkVersion = "2.30.860"
val cogniteSdkVersion = "2.31.861"

val prometheusVersion = "0.16.0"
val log4sVersion = "1.10.0"
@@ -41,7 +41,7 @@ lazy val commonSettings = Seq(
organization := "com.cognite.spark.datasource",
organizationName := "Cognite",
organizationHomepage := Some(url("https://cognite.com")),
version := "3.20." + patchVersion,
version := "3.21." + patchVersion,
isSnapshot := patchVersion.endsWith("-SNAPSHOT"),
crossScalaVersions := supportedScalaVersions,
semanticdbEnabled := true,
6 changes: 3 additions & 3 deletions build.scala-2.12.sbt.lock
@@ -1,6 +1,6 @@
{
"lockVersion" : 1,
"timestamp" : "2024-10-31T16:04:08.459482Z",
"timestamp" : "2024-11-13T11:30:40.138510863Z",
"configurations" : [
"compile",
"optional",
@@ -94,11 +94,11 @@
{
"org" : "com.cognite",
"name" : "cognite-sdk-scala_2.12",
"version" : "2.30.860",
"version" : "2.31.861",
"artifacts" : [
{
"name" : "cognite-sdk-scala_2.12.jar",
"hash" : "sha1:559a261df9d7bd698ec97dc45ae416870ecd2f62"
"hash" : "sha1:ae52606e613a31e1979830a6036f265328033908"
}
],
"configurations" : [
6 changes: 3 additions & 3 deletions build.scala-2.13.sbt.lock
@@ -1,6 +1,6 @@
{
"lockVersion" : 1,
"timestamp" : "2024-10-31T16:04:15.967563Z",
"timestamp" : "2024-11-13T11:30:56.507065520Z",
"configurations" : [
"compile",
"optional",
@@ -94,11 +94,11 @@
{
"org" : "com.cognite",
"name" : "cognite-sdk-scala_2.13",
"version" : "2.30.860",
"version" : "2.31.861",
"artifacts" : [
{
"name" : "cognite-sdk-scala_2.13.jar",
"hash" : "sha1:2712efe507294923d9cc5ca2998f3d452096ac14"
"hash" : "sha1:245bd96891942232ba8d5a8e2c7ff1714cf2c230"
}
],
"configurations" : [
6 changes: 3 additions & 3 deletions macro/build.scala-2.12.sbt.lock
@@ -1,6 +1,6 @@
{
"lockVersion" : 1,
"timestamp" : "2024-10-31T16:04:18.098577Z",
"timestamp" : "2024-11-13T11:30:59.215321253Z",
"configurations" : [
"compile",
"optional",
@@ -75,11 +75,11 @@
{
"org" : "com.cognite",
"name" : "cognite-sdk-scala_2.12",
"version" : "2.30.860",
"version" : "2.31.861",
"artifacts" : [
{
"name" : "cognite-sdk-scala_2.12.jar",
"hash" : "sha1:559a261df9d7bd698ec97dc45ae416870ecd2f62"
"hash" : "sha1:ae52606e613a31e1979830a6036f265328033908"
}
],
"configurations" : [
6 changes: 3 additions & 3 deletions macro/build.scala-2.13.sbt.lock
@@ -1,6 +1,6 @@
{
"lockVersion" : 1,
"timestamp" : "2024-10-31T16:04:19.995959Z",
"timestamp" : "2024-11-13T11:31:01.766754723Z",
"configurations" : [
"compile",
"optional",
@@ -75,11 +75,11 @@
{
"org" : "com.cognite",
"name" : "cognite-sdk-scala_2.13",
"version" : "2.30.860",
"version" : "2.31.861",
"artifacts" : [
{
"name" : "cognite-sdk-scala_2.13.jar",
"hash" : "sha1:2712efe507294923d9cc5ca2998f3d452096ac14"
"hash" : "sha1:245bd96891942232ba8d5a8e2c7ff1714cf2c230"
}
],
"configurations" : [
4 changes: 3 additions & 1 deletion src/main/scala/cognite/spark/v1/DefaultSource.scala
@@ -481,7 +481,9 @@ object DefaultSource {
rawEnsureParent = toBoolean(parameters, "rawEnsureParent", defaultValue = true),
enableSinglePartitionDeleteAssetHierarchy = enableSinglePartitionDeleteAssetHierarchy,
tracingParent = extractTracingHeadersKernel(parameters),
useSharedThrottle = toBoolean(parameters, "useSharedThrottle", defaultValue = false)
useSharedThrottle = toBoolean(parameters, "useSharedThrottle", defaultValue = false),
serverSideFilterNullValuesOnNonSchemaRawQueries =
toBoolean(parameters, "filterNullFieldsOnNonSchemaRawQueries", defaultValue = false)
)
}

20 changes: 15 additions & 5 deletions src/main/scala/cognite/spark/v1/RawTableRelation.scala
@@ -66,6 +66,7 @@ class RawTableRelation(
filter = RawRowFilter(),
requestedKeys = None,
schema = None,
filterNulls = false,
collectMetrics = collectSchemaInferenceMetrics,
collectTestMetrics = false
)
@@ -83,15 +84,17 @@
}
}

private def getStreams(filter: RawRowFilter, cursors: Vector[String])(
private def getStreams(filter: RawRowFilter, filterNullFields: Boolean, cursors: Vector[String])(
limit: Option[Int],
numPartitions: Int)(client: GenericClient[IO]): Seq[Stream[IO, RawRow]] = {
assert(numPartitions == cursors.length)
val rawClient = client.rawRows(database, table)
val rawClient = client.rawRows(database, table, filterNullFields)
cursors.map(rawClient.filterOnePartition(filter, _, limit))
}

private def getStreamByKeys(client: GenericClient[IO], keys: Set[String]): Stream[IO, RawRow] = {
// Note that retrieveByKey does not currently support filtering out null fields. When/If that is
// added, we should also pass in the flag to filter out those here.
val rawClient = client.rawRows(database, table)
Stream
.emits(keys.toSeq)
@@ -127,6 +130,7 @@ class RawTableRelation(
filter: RawRowFilter,
requestedKeys: Option[Set[String]],
schema: Option[StructType],
filterNulls: Boolean,
collectMetrics: Boolean = config.collectMetrics,
collectTestMetrics: Boolean = config.collectTestMetrics): RDD[Row] = {
val configWithLimit =
@@ -142,11 +146,11 @@
val partitionCursors =
CdpConnector
.clientFromConfig(config)
.rawRows(database, table)
.rawRows(database, table, filterNulls)
.getPartitionCursors(filter, configWithLimit.partitions)
.unsafeRunSync()
.toVector
getStreams(filter, partitionCursors)(
getStreams(filter, filterNulls, partitionCursors)(
configWithLimit.limitPerPartition,
configWithLimit.partitions)
}
@@ -203,7 +207,13 @@ }
}

val rdd =
readRows(config.limitPerPartition, None, rawRowFilter, requestedKeys, jsonSchema)
readRows(
config.limitPerPartition,
None,
rawRowFilter,
requestedKeys,
jsonSchema,
config.serverSideFilterNullValuesOnNonSchemaRawQueries)

rdd.map(row => {
val filteredCols = requiredColumns.map(colName => row.get(schema.fieldIndex(colName)))
3 changes: 2 additions & 1 deletion src/main/scala/cognite/spark/v1/RelationConfig.scala
@@ -28,7 +28,8 @@ final case class RelationConfig(
enableSinglePartitionDeleteAssetHierarchy: Boolean, // flag to test whether single partition helps avoid NPE in asset hierarchy builder
tracingParent: Kernel,
initialRetryDelayMillis: Int,
useSharedThrottle: Boolean
useSharedThrottle: Boolean,
serverSideFilterNullValuesOnNonSchemaRawQueries: Boolean,
) {

/** Desired number of Spark partitions ~= partitions / parallelismPerPartition */
92 changes: 90 additions & 2 deletions src/test/scala/cognite/spark/v1/RawTableRelationTest.scala
@@ -5,7 +5,7 @@ import cats.implicits._
import cognite.spark.v1.CdpConnector.ioRuntime
import com.cognite.sdk.scala.common.CdpApiException
import com.cognite.sdk.scala.v1.{RawDatabase, RawRow, RawTable}
import io.circe.Json
import io.circe.{Json, JsonObject}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
@@ -117,6 +117,16 @@
RawRow("k2", Map("bool" -> Json.fromBoolean(java.lang.Boolean.parseBoolean("true")))),
RawRow("k3", Map("bool" -> Json.fromBoolean(false)))
)
private val dataWithNullFieldValue = Seq(
RawRow("k1", Map("toBeFiltered" -> Json.Null)),
RawRow("k2", Map("toBeFiltered" -> Json.Null, "notFiltered" -> Json.fromString("string"))),
RawRow("k3", Map("toBeFiltered" -> Json.fromString("but not here"), "notFiltered" -> Json.fromString("string2")))
)
private val dataWithEmptyColumn = Seq(
RawRow("k1", Map("toBeFiltered" -> Json.Null, "notFiltered" -> Json.fromString("string_k1"))),
RawRow("k2", Map("toBeFiltered" -> Json.Null, "notFiltered" -> Json.fromString("string_k2"))),
RawRow("k3", Map("toBeFiltered" -> Json.Null, "notFiltered" -> Json.fromString("string_k3"))),
)

case class TestTable(name: String, data: Seq[RawRow])
case class TestData(dbName: String, tables: Seq[TestTable])
@@ -137,6 +147,8 @@
TestTable("with-long-empty-str", dataWithEmptyStringInLongField),
TestTable("with-number-empty-str", dataWithEmptyStringInDoubleField),
TestTable("with-boolean-empty-str", dataWithEmptyStringInBooleanField),
TestTable("with-some-null-values", dataWithNullFieldValue),
TestTable("with-only-null-values-for-field", dataWithEmptyColumn),
TestTable("cryptoAssets", (1 to 500).map(i =>
RawRow(i.toString, Map("i" -> Json.fromString("exist")))
)),
@@ -237,7 +249,8 @@
table: String,
database: String = "spark-test-database",
inferSchema: Boolean = true,
metricsPrefix: Option[String] = None): DataFrame = {
metricsPrefix: Option[String] = None,
filterNullFields: Option[Boolean] = None): DataFrame = {
val df = spark.read
.format(DefaultSource.sparkFormatString)
.useOIDCWrite
@@ -248,6 +261,8 @@
.option("inferSchema", inferSchema)
.option("inferSchemaLimit", "100")

filterNullFields.foreach(v => df.option("filterNullFieldsOnNonSchemaRawQueries", v.toString))

metricsPrefix match {
case Some(prefix) =>
df.option("collectMetrics", "true")
Expand Down Expand Up @@ -894,4 +909,77 @@ class RawTableRelationTest
err.getMessage shouldBe "Error while loading RAW row [key='k'] in column 'value': java.lang.NumberFormatException: For input string: \"test\""

}

it should "filter out fields with null value but not impact schema inference" in {
val tableName = "with-some-null-values"
val df = rawRead(table = tableName, database = testData.dbName, inferSchema = true, filterNullFields = Some(true))
df.count() shouldBe 3
df.schema.fieldNames.toSet shouldBe Set("key", "lastUpdatedTime", "notFiltered", "toBeFiltered")
val items = RawJsonConverter.rowsToRawItems(df.columns, "key", df.collect().toSeq).map(r => (r.key, r.columns)).toMap
items("k1")("toBeFiltered") shouldBe Json.Null
items("k2")("toBeFiltered") shouldBe Json.Null
items("k2")("notFiltered") shouldBe Json.fromString("string")
items("k3")("toBeFiltered") shouldBe Json.fromString("but not here")
items("k3")("notFiltered") shouldBe Json.fromString("string2")
}

it should "filter out columns completely when not inferring schema (confirming it filters from RAW)" in {
val tableName = "with-some-null-values"
val df = rawRead(table = tableName, database = testData.dbName, inferSchema = false, filterNullFields = Some(true))
df.count() shouldBe 3
val items = RawJsonConverter.rowsToRawItems(df.columns, "key", df.collect().toSeq).map(r => (r.key, r.columns)).toMap
items("k1")("columns") shouldBe Json.fromString("{}")
items("k2")("columns") shouldBe Json.fromString("{\"notFiltered\":\"string\"}")
val columnsParsed: JsonObject = parseColumns(items("k3"))
columnsParsed("notFiltered") shouldBe Some(Json.fromString("string2"))
columnsParsed("toBeFiltered") shouldBe Some(Json.fromString("but not here"))
}

it should "return column in schema, even if every row has it filtered out" in {
val tableName = "with-only-null-values-for-field"
val df = rawRead(table = tableName, database = testData.dbName, inferSchema = true, filterNullFields = Some(true))
df.count() shouldBe 3
df.schema.fieldNames.toSet shouldBe Set("key", "lastUpdatedTime", "notFiltered", "toBeFiltered")
}

it should "not filter out null column values when filtering is not set" in {
val tableName = "with-some-null-values"
// We run this without inferSchema, as schema would hide whether the fields are filtered or not.
val df = rawRead(table = tableName, database = testData.dbName, inferSchema = false)
df.count() shouldBe 3
validateWhenFilteringIsNotEnabled(df)
}

it should "not filter out null column values when filtering is explicitly disabled" in {
val tableName = "with-some-null-values"
// We run this without inferSchema, as schema would hide whether the fields are filtered or not.
val df = rawRead(table = tableName, database = testData.dbName, inferSchema = false, filterNullFields = Some(false))
df.count() shouldBe 3
validateWhenFilteringIsNotEnabled(df)
}

private def validateWhenFilteringIsNotEnabled(df: DataFrame): Unit = {
val rows: Map[String, Map[String, Json]] = RawJsonConverter.rowsToRawItems(df.columns, "key", df.collect().toSeq)
.map(r => (r.key, r.columns))
.toMap

rows("k1")("columns") shouldBe Json.fromString("{\"toBeFiltered\":null}")

val columnsParsedk2: JsonObject = parseColumns(rows("k2"))
columnsParsedk2("toBeFiltered") shouldBe Some(Json.Null)
columnsParsedk2("notFiltered") shouldBe Some(Json.fromString("string"))

val columnsParsedk3 = parseColumns(rows("k3"))
columnsParsedk3("toBeFiltered") shouldBe Some(Json.fromString("but not here"))
columnsParsedk3("notFiltered") shouldBe Some(Json.fromString("string2"))
()
}

private def parseColumns(row: Map[String, Json]): JsonObject = {
io.circe.parser.parse(row("columns").asString.get) match {
case Right(json) => json.asObject.get
case Left(error) => throw error
}
}

}
3 changes: 2 additions & 1 deletion src/test/scala/cognite/spark/v1/SparkTest.scala
@@ -273,7 +273,8 @@ trait SparkTest {
rawEnsureParent = false,
enableSinglePartitionDeleteAssetHierarchy = false,
tracingParent = new Kernel(Map.empty),
useSharedThrottle = false
useSharedThrottle = false,
serverSideFilterNullValuesOnNonSchemaRawQueries = false
)

private def getCounterSafe(metricsNamespace: String, resource: String): Option[Long] = {
