diff --git a/NOTICE b/NOTICE
new file mode 100644
index 00000000000..0662cf42b17
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,13 @@
+RAPIDS plugin for Apache Spark
+Copyright (c) 2019, NVIDIA CORPORATION
+
+// ------------------------------------------------------------------
+// NOTICE file corresponding to the section 4d of The Apache License,
+// Version 2.0, in this case for the RAPIDS plugin for Apache Spark distribution
+// ------------------------------------------------------------------
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/NOTICE-binary b/NOTICE-binary
new file mode 100644
index 00000000000..ad40699ab14
--- /dev/null
+++ b/NOTICE-binary
@@ -0,0 +1,24 @@
+RAPIDS plugin for Apache Spark
+Copyright (c) 2019, NVIDIA CORPORATION
+
+// ------------------------------------------------------------------
+// NOTICE file corresponding to the section 4d of The Apache License,
+// Version 2.0, in this case for the RAPIDS plugin for Apache Spark distribution
+// ------------------------------------------------------------------
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+---------------------------------------------------------------------
+
+Apache ORC
+Copyright 2013-2019 The Apache Software Foundation
+
+This product includes software developed by The Apache Software
+Foundation (http://www.apache.org/).
+
+This product includes software developed by Hewlett-Packard:
+(c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P
diff --git a/README.md b/README.md
index 89e862ad7b7..bd49700599a 100644
--- a/README.md
+++ b/README.md
@@ -17,3 +17,8 @@ enable it again.
> spark-shell --jars 'rapids-4-spark-0.8-SNAPSHOT.jar,cudf-0.8-SNAPSHOT-cuda10.jar' --conf spark.sql.extensions=ai.rapids.spark.Plugin
```
+## Notes on Building
+
+The build requires Apache Spark 3.0+ built against the `nohive` ORC classifier. Building against
+a version of Spark that was built without nohive support (e.g., the Apache Spark `hadoop-3.2`
+profile) will cause build and test failures due to conflicting ORC jars.
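One quick way to tell which ORC variant a given Spark build pulled in is to probe for the two
package locations of `SearchArgument`: the nohive ORC jars relocate the hive-storage-api classes
under `org.apache.orc.storage`, while regular ORC builds use `org.apache.hadoop.hive`. A minimal
sketch (not part of this patch) that can be pasted into `spark-shell`:

```scala
// Probe which ORC SearchArgument variant is on the classpath.
def hasClass(name: String): Boolean =
  try { Class.forName(name); true } catch { case _: ClassNotFoundException => false }

println("nohive ORC on classpath: " +
  hasClass("org.apache.orc.storage.ql.io.sarg.SearchArgument"))
println("hive-based storage-api on classpath: " +
  hasClass("org.apache.hadoop.hive.ql.io.sarg.SearchArgument"))
```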
diff --git a/pom.xml b/pom.xml
index e65f160690c..8425142d5e1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -20,6 +20,9 @@
     <cuda.version>cuda10</cuda.version>
     <scala.binary.version>2.12</scala.binary.version>
     <scala.version>2.12.8</scala.version>
+    <orc.version>1.5.5</orc.version>
+    <orc.classifier>nohive</orc.classifier>
+    <rapids.shade.package>ai.rapids.spark.shaded</rapids.shade.package>
@@ -43,9 +46,76 @@
       <version>1.1.1</version>
     </dependency>
     <dependency>
-      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-sql_${scala.binary.version}</artifactId>
-      <version>3.0.0-SNAPSHOT</version>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>3.0.0-SNAPSHOT</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-core</artifactId>
+      <version>${orc.version}</version>
+      <classifier>${orc.classifier}</classifier>
+      <exclusions>
+        <exclusion>
+          <groupId>javax.xml.bind</groupId>
+          <artifactId>jaxb-api</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-common</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-hdfs</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hive</groupId>
+          <artifactId>hive-storage-api</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-mapreduce</artifactId>
+      <version>${orc.version}</version>
+      <classifier>${orc.classifier}</classifier>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-common</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-mapreduce-client-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.orc</groupId>
+          <artifactId>orc-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hive</groupId>
+          <artifactId>hive-storage-api</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.esotericsoftware</groupId>
+          <artifactId>kryo-shaded</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-shims</artifactId>
+      <version>${orc.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-common</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-hdfs</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
 
   <build>
+    <resources>
+      <resource>
+        <directory>${project.basedir}/..</directory>
+        <targetPath>META-INF</targetPath>
+        <includes>
+          <include>LICENSE</include>
+        </includes>
+      </resource>
+    </resources>
     <plugins>
+      <plugin>
+        <artifactId>maven-antrun-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>copy-notice</id>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <phase>process-resources</phase>
+            <configuration>
+              <target>
+                <!-- copy NOTICE-binary to NOTICE -->
+                <copy todir="${project.build.directory}/classes/META-INF/" verbose="true">
+                  <fileset dir="${project.basedir}/..">
+                    <include name="NOTICE-binary"/>
+                  </fileset>
+                  <mapper type="glob" from="*-binary" to="*"/>
+                </copy>
+              </target>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <createDependencyReducedPom>true</createDependencyReducedPom>
+          <artifactSet>
+            <includes>
+              <include>org.apache.orc:orc-core</include>
+              <include>org.apache.orc:orc-mapreduce</include>
+              <include>org.apache.orc:orc-shims</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>org.apache.orc</pattern>
+              <shadedPattern>${rapids.shade.package}.orc</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>com.google.protobuf25</pattern>
+              <shadedPattern>${rapids.shade.package}.protobuf25</shadedPattern>
+            </relocation>
+          </relocations>
+        </configuration>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>shade</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
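For reference, a quick way to confirm that the relocations above took effect is to look for the
shaded package prefix inside the built plugin jar. A minimal sketch (the jar path is illustrative;
`ai.rapids.spark.shaded` comes from the `rapids.shade.package` property):

```scala
import java.util.jar.JarFile
import scala.collection.JavaConverters._

// List the ORC classes that were relocated under the rapids.shade.package prefix.
val jar = new JarFile("rapids-4-spark-0.8-SNAPSHOT.jar") // illustrative path
val relocated = jar.entries().asScala
  .map(_.getName)
  .filter(_.startsWith("ai/rapids/spark/shaded/orc/"))
  .toList
println(s"relocated ORC entries: ${relocated.size}")
jar.close()
```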
diff --git a/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala b/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
index ab430f93de1..1271813081b 100644
--- a/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.types.StructType
@@ -68,10 +69,14 @@ class GpuOrcScan(
readDataSchema, readPartitionSchema, options, pushedFilters) with GpuScan {
override def createReaderFactory(): PartitionReaderFactory = {
+ // Unset any serialized search argument set up by Spark's OrcScanBuilder, as it
+ // will be incompatible here due to shading and a potential ORC classifier mismatch.
+ hadoopConf.unset(OrcConf.KRYO_SARG.getAttribute)
+
val broadcastedConf = sparkSession.sparkContext.broadcast(
new GpuSerializableConfiguration(hadoopConf))
GpuOrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
- dataSchema, readDataSchema, readPartitionSchema, rapidsConf)
+ dataSchema, readDataSchema, readPartitionSchema, pushedFilters, rapidsConf)
}
}
@@ -92,6 +97,7 @@ case class GpuOrcPartitionReaderFactory(
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
+ pushedFilters: Array[Filter],
@transient rapidsConf: RapidsConf) extends FilePartitionReaderFactory {
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val debugDumpPrefix = rapidsConf.orcDebugDumpPrefix
@@ -108,8 +114,9 @@ case class GpuOrcPartitionReaderFactory(
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, orcSchemaString)
OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
+ val fullSchema = StructType(dataSchema ++ partitionSchema)
val reader = new GpuOrcPartitionReader(conf, partFile, dataSchema, readDataSchema,
- debugDumpPrefix)
+ fullSchema, pushedFilters, debugDumpPrefix)
ColumnarPartitionReaderWithPartitionValues.newReader(partFile, reader, partitionSchema)
}
}
@@ -155,6 +162,8 @@ class GpuOrcPartitionReader(
partFile: PartitionedFile,
dataSchema: StructType,
readDataSchema: StructType,
+ fullSchema: StructType,
+ pushedFilters: Array[Filter],
debugDumpPrefix: String) extends PartitionReader[ColumnarBatch] with Logging {
private var batch: Option[ColumnarBatch] = None
private var isExhausted: Boolean = false
@@ -598,6 +607,11 @@ class GpuOrcPartitionReader(
splitStripes: Seq[StripeInformation],
useUTCTimestamp: Boolean): Option[Table] = {
val readerOpts = OrcInputFormat.buildOptions(conf, orcReader, partFile.start, partFile.length)
+ // create the search argument if we have pushed filters
+ OrcFilters.createFilter(fullSchema, pushedFilters).foreach { f =>
+ readerOpts.searchArgument(f, fullSchema.fieldNames)
+ }
+
val updatedReadSchema = checkSchemaCompatibility(orcReader.getSchema, readerOpts.getSchema,
readerOpts.getIsSchemaEvolutionCaseAware)
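To make the wiring above concrete, here is a minimal sketch (schema, filters, and values are
illustrative; it assumes the plugin jar is on the classpath) of how the new
`OrcFilters.createFilter` call is used: the pushed-down Spark `Filter`s are combined into a single
ORC `SearchArgument`, which is then handed to the ORC reader options.

```scala
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.{Filter, GreaterThan, IsNotNull}
import org.apache.spark.sql.types._

// Illustrative schema and pushed-down filters, mirroring what GpuOrcPartitionReader receives.
val fullSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)))
val pushedFilters: Array[Filter] = Array(IsNotNull("id"), GreaterThan("id", 100L))

// Only build and apply a search argument when at least one filter was pushed down.
OrcFilters.createFilter(fullSchema, pushedFilters).foreach { sarg =>
  // In GpuOrcPartitionReader this becomes: readerOpts.searchArgument(sarg, fullSchema.fieldNames)
  println(s"ORC search argument: $sarg")
}
```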
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
new file mode 100644
index 00000000000..298efebd5a7
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.rapids
+
+import org.apache.orc.storage.common.`type`.HiveDecimal
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
+import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable
+
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.datasources.orc.OrcFiltersBase
+
+// This is derived from Apache Spark's OrcFilters code to avoid calling the
+// Spark version. Spark's version can potentially create a search argument
+// applier object that is incompatible with the orc:nohive jar that has been
+// shaded as part of this project.
+//
+// The use of Spark's OrcFiltersBase is safe since it is explicitly designed
+// to be reusable across different ORC jar classifiers.
+object OrcFilters extends OrcFiltersBase {
+
+ /**
+ * Create ORC filter as a SearchArgument instance.
+ *
+ * NOTE: These filters should be pre-filtered by Spark to only contain the
+ * filters convertible to ORC, so checking what is convertible is
+ * not necessary here.
+ */
+ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+ val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+ // Combines all filters using `And` to produce a single conjunction
+ val conjunctionOptional = buildTree(filters)
+ conjunctionOptional.map { conjunction =>
+ // Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
+ // The input predicate is fully convertible. There should not be any empty result in the
+ // following recursive method call `buildSearchArgument`.
+ buildSearchArgument(dataTypeMap, conjunction, newBuilder).build()
+ }
+ }
+
+ /**
+ * Get the PredicateLeaf.Type corresponding to the given DataType.
+ */
+ private def getPredicateLeafType(dataType: DataType) = dataType match {
+ case BooleanType => PredicateLeaf.Type.BOOLEAN
+ case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+ case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
+ case StringType => PredicateLeaf.Type.STRING
+ case DateType => PredicateLeaf.Type.DATE
+ case TimestampType => PredicateLeaf.Type.TIMESTAMP
+ case _: DecimalType => PredicateLeaf.Type.DECIMAL
+ case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}")
+ }
+
+ /**
+ * Cast literal values for filters.
+ *
+ * We need to cast the literal values (e.g. integral types to Long) because ORC's
+ * SearchArgumentImpl raises exceptions in 'checkLiteralType' for unexpected literal types.
+ */
+ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
+ case ByteType | ShortType | IntegerType | LongType =>
+ value.asInstanceOf[Number].longValue
+ case FloatType | DoubleType =>
+ value.asInstanceOf[Number].doubleValue()
+ case _: DecimalType =>
+ new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal]))
+ case _ => value
+ }
+
+ /**
+ * Build a SearchArgument and return the builder so far.
+ *
+ * @param dataTypeMap a map from the attribute name to its data type.
+ * @param expression the input predicates, which should be fully convertible to SearchArgument.
+ * @param builder the input SearchArgument.Builder.
+ * @return the builder so far.
+ */
+ private def buildSearchArgument(
+ dataTypeMap: Map[String, DataType],
+ expression: Filter,
+ builder: Builder): Builder = {
+ import org.apache.spark.sql.sources._
+
+ expression match {
+ case And(left, right) =>
+ val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd())
+ val rhs = buildSearchArgument(dataTypeMap, right, lhs)
+ rhs.end()
+
+ case Or(left, right) =>
+ val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr())
+ val rhs = buildSearchArgument(dataTypeMap, right, lhs)
+ rhs.end()
+
+ case Not(child) =>
+ buildSearchArgument(dataTypeMap, child, builder.startNot()).end()
+
+ case other =>
+ buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse {
+ throw new SparkException(
+ "The input filter of OrcFilters.buildSearchArgument should be fully convertible.")
+ }
+ }
+ }
+
+ /**
+ * Build a SearchArgument for a leaf predicate and return the builder so far.
+ *
+ * @param dataTypeMap a map from the attribute name to its data type.
+ * @param expression the input filter predicates.
+ * @param builder the input SearchArgument.Builder.
+ * @return the builder so far.
+ */
+ private def buildLeafSearchArgument(
+ dataTypeMap: Map[String, DataType],
+ expression: Filter,
+ builder: Builder): Option[Builder] = {
+ def getType(attribute: String): PredicateLeaf.Type =
+ getPredicateLeafType(dataTypeMap(attribute))
+
+ import org.apache.spark.sql.sources._
+
+ // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
+ // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
+ // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
+ expression match {
+ case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end())
+
+ case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end())
+
+ case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end())
+
+ case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end())
+
+ case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end())
+
+ case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end())
+
+ case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
+
+ case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
+
+ case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+ val quotedName = quoteAttributeNameIfNeeded(attribute)
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
+ Some(builder.startAnd().in(quotedName, getType(attribute),
+ castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
+
+ case _ => None
+ }
+ }
+}
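One detail worth calling out from the leaf cases above: the ORC builder has no `greaterThan`
leaf, so `GreaterThan` and `GreaterThanOrEqual` are expressed by negating `lessThanEquals` and
`lessThan`. A minimal sketch (assuming the nohive ORC jars are on the classpath; the column name
and literal are illustrative) of the equivalent builder calls, with the integral literal widened
to `Long` as `castLiteralValue` does:

```scala
import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgumentFactory}

// GreaterThan("age", 21) is built as NOT(age <= 21); the literal is passed as a Long
// because ORC's SearchArgumentImpl rejects other integral literal types.
val sarg = SearchArgumentFactory.newBuilder()
  .startNot()
  .lessThanEquals("age", PredicateLeaf.Type.LONG, java.lang.Long.valueOf(21L))
  .end()
  .build()
println(sarg)
```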