diff --git a/NOTICE b/NOTICE
new file mode 100644
index 00000000000..0662cf42b17
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,13 @@
+RAPIDS plugin for Apache Spark
+Copyright (c) 2019, NVIDIA CORPORATION
+
+// ------------------------------------------------------------------
+// NOTICE file corresponding to the section 4d of The Apache License,
+// Version 2.0, in this case for
+// ------------------------------------------------------------------
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/NOTICE-binary b/NOTICE-binary
new file mode 100644
index 00000000000..ad40699ab14
--- /dev/null
+++ b/NOTICE-binary
@@ -0,0 +1,24 @@
+RAPIDS plugin for Apache Spark
+Copyright (c) 2019, NVIDIA CORPORATION
+
+// ------------------------------------------------------------------
+// NOTICE file corresponding to the section 4d of The Apache License,
+// Version 2.0, in this case for
+// ------------------------------------------------------------------
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+---------------------------------------------------------------------
+
+Apache ORC
+Copyright 2013-2019 The Apache Software Foundation
+
+This product includes software developed by The Apache Software
+Foundation (http://www.apache.org/).
+
+This product includes software developed by Hewlett-Packard:
+(c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P
diff --git a/README.md b/README.md
index 89e862ad7b7..bd49700599a 100644
--- a/README.md
+++ b/README.md
@@ -17,3 +17,8 @@ enable it again.
 > spark-shell --jars 'rapids-4-spark-0.8-SNAPSHOT.jar,cudf-0.8-SNAPSHOT-cuda10.jar' --conf spark.sql.extensions=ai.rapids.spark.Plugin
 ```
 
+## Notes on Building
+
+The build requires Apache Spark 3.0+ built against the nohive ORC classifier. Building with a
+version of Spark built without nohive support (e.g., the Apache Spark hadoop-3.2 profile) will
+cause errors during build and test due to conflicting ORC jars.
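The nohive classifier matters because the nohive ORC jars relocate the Hive storage-api classes under `org.apache.orc.storage`, while the regular ORC jars expect them under `org.apache.hadoop.hive`; mixing the two flavors is what produces the build and test failures mentioned in the README note above. As a rough, hypothetical diagnostic (not part of this change, and the object name is illustrative), one can probe the classpath to see which flavor of the `SearchArgument` API is visible:

```scala
// Hypothetical classpath probe: reports which ORC SearchArgument flavor is visible.
object OrcSargProbe {
  private def onClasspath(name: String): Boolean =
    try {
      Class.forName(name)
      true
    } catch {
      case _: ClassNotFoundException => false
    }

  def main(args: Array[String]): Unit = {
    // The nohive ORC jars relocate the Hive storage-api under org.apache.orc.storage.
    println("nohive SearchArgument visible: " +
      onClasspath("org.apache.orc.storage.ql.io.sarg.SearchArgument"))
    // The regular ORC jars use the Hive storage-api package directly.
    println("hive SearchArgument visible: " +
      onClasspath("org.apache.hadoop.hive.ql.io.sarg.SearchArgument"))
  }
}
```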
diff --git a/pom.xml b/pom.xml
index e65f160690c..8425142d5e1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -20,6 +20,9 @@
     cuda10
     2.12
     2.12.8
+    1.5.5
+    nohive
+    ai.rapids.spark.shaded
@@ -43,9 +46,76 @@
     1.1.1
-      org.apache.spark
-      spark-sql_${scala.binary.version}
-      3.0.0-SNAPSHOT
+      org.apache.spark
+      spark-sql_${scala.binary.version}
+      3.0.0-SNAPSHOT
+
+
+      org.apache.orc
+      orc-core
+      ${orc.version}
+      ${orc.classifier}
+
+
+          javax.xml.bind
+          jaxb-api
+
+
+          org.apache.hadoop
+          hadoop-common
+
+
+          org.apache.hadoop
+          hadoop-hdfs
+
+
+          org.apache.hive
+          hive-storage-api
+
+
+
+
+      org.apache.orc
+      orc-mapreduce
+      ${orc.version}
+      ${orc.classifier}
+
+
+          org.apache.hadoop
+          hadoop-common
+
+
+          org.apache.hadoop
+          hadoop-mapreduce-client-core
+
+
+          org.apache.orc
+          orc-core
+
+
+          org.apache.hive
+          hive-storage-api
+
+
+          com.esotericsoftware
+          kryo-shaded
+
+
+
+
+      org.apache.orc
+      orc-shims
+      ${orc.version}
+
+
+          org.apache.hadoop
+          hadoop-common
+
+
+          org.apache.hadoop
+          hadoop-hdfs
+
+
+      LICENSE
+
+
+
+
+        maven-antrun-plugin
+
+
+            copy-notice
+
+              run
+
+            process-resources
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+        org.apache.maven.plugins
+        maven-shade-plugin
+
+          false
+          true
+
+
+              org.apache.orc:orc-core
+              org.apache.orc:orc-mapreduce
+              org.apache.orc:orc-shims
+
+
+
+
+              org.apache.orc
+              ${rapids.shade.package}.orc
+
+
+              com.google.protobuf25
+              ${rapids.shade.package}.protobuf25
+
+
+
+
+
+              package
+
+                shade
+
+
+
       org.apache.maven.plugins
diff --git a/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala b/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
index ab430f93de1..1271813081b 100644
--- a/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/ai/rapids/spark/GpuOrcScan.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.OrcFilters
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.types.StructType
@@ -68,10 +69,14 @@ class GpuOrcScan(
     readDataSchema, readPartitionSchema, options, pushedFilters) with GpuScan {
 
   override def createReaderFactory(): PartitionReaderFactory = {
+    // Unset any serialized search argument set up by Spark's OrcScanBuilder, as
+    // it will be incompatible due to shading and a potential ORC classifier mismatch.
+    hadoopConf.unset(OrcConf.KRYO_SARG.getAttribute)
+
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new GpuSerializableConfiguration(hadoopConf))
     GpuOrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema, rapidsConf)
+      dataSchema, readDataSchema, readPartitionSchema, pushedFilters, rapidsConf)
   }
 }
@@ -92,6 +97,7 @@ case class GpuOrcPartitionReaderFactory(
     dataSchema: StructType,
     readDataSchema: StructType,
     partitionSchema: StructType,
+    pushedFilters: Array[Filter],
     @transient rapidsConf: RapidsConf) extends FilePartitionReaderFactory {
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val debugDumpPrefix = rapidsConf.orcDebugDumpPrefix
@@ -108,8 +114,9 @@ case class GpuOrcPartitionReaderFactory(
     OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, orcSchemaString)
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
 
+    val fullSchema = StructType(dataSchema ++ partitionSchema)
     val reader = new GpuOrcPartitionReader(conf, partFile, dataSchema, readDataSchema,
-      debugDumpPrefix)
+      fullSchema, pushedFilters, debugDumpPrefix)
     ColumnarPartitionReaderWithPartitionValues.newReader(partFile, reader, partitionSchema)
   }
 }
@@ -155,6 +162,8 @@ class GpuOrcPartitionReader(
     partFile: PartitionedFile,
     dataSchema: StructType,
     readDataSchema: StructType,
+    fullSchema: StructType,
+    pushedFilters: Array[Filter],
     debugDumpPrefix: String) extends PartitionReader[ColumnarBatch] with Logging {
   private var batch: Option[ColumnarBatch] = None
   private var isExhausted: Boolean = false
@@ -598,6 +607,11 @@ class GpuOrcPartitionReader(
       splitStripes: Seq[StripeInformation],
       useUTCTimestamp: Boolean): Option[Table] = {
     val readerOpts = OrcInputFormat.buildOptions(conf, orcReader, partFile.start, partFile.length)
+    // create the search argument if we have pushed filters
+    OrcFilters.createFilter(fullSchema, pushedFilters).foreach { f =>
+      readerOpts.searchArgument(f, fullSchema.fieldNames)
+    }
+
     val updatedReadSchema = checkSchemaCompatibility(orcReader.getSchema, readerOpts.getSchema,
       readerOpts.getIsSchemaEvolutionCaseAware)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
new file mode 100644
index 00000000000..298efebd5a7
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.rapids
+
+import org.apache.orc.storage.common.`type`.HiveDecimal
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
+import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable
+
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.datasources.orc.OrcFiltersBase
+
+// This is derived from Apache Spark's OrcFilters code to avoid calling the
+// Spark version. Spark's version can potentially create a search argument
+// applier object that is incompatible with the orc:nohive jar that has been
+// shaded as part of this project.
+//
+// The use of Spark's OrcFiltersBase is safe since it is explicitly designed
+// to be reusable across different ORC jar classifiers.
+object OrcFilters extends OrcFiltersBase {
+
+  /**
+   * Create an ORC filter as a SearchArgument instance.
+   *
+   * NOTE: These filters should be pre-filtered by Spark to only contain the
+   * filters convertible to ORC, so checking what is convertible is
+   * not necessary here.
+   */
+  def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+    val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+    // Combines all filters using `And` to produce a single conjunction
+    val conjunctionOptional = buildTree(filters)
+    conjunctionOptional.map { conjunction =>
+      // Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
+      // The input predicate is fully convertible. There should not be any empty result in the
+      // following recursive method call `buildSearchArgument`.
+      buildSearchArgument(dataTypeMap, conjunction, newBuilder).build()
+    }
+  }
+
+  /**
+   * Get the PredicateLeaf.Type corresponding to the given DataType.
+   */
+  private def getPredicateLeafType(dataType: DataType) = dataType match {
+    case BooleanType => PredicateLeaf.Type.BOOLEAN
+    case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+    case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
+    case StringType => PredicateLeaf.Type.STRING
+    case DateType => PredicateLeaf.Type.DATE
+    case TimestampType => PredicateLeaf.Type.TIMESTAMP
+    case _: DecimalType => PredicateLeaf.Type.DECIMAL
+    case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}")
+  }
+
+  /**
+   * Cast literal values for filters.
+   *
+   * We need to cast to long because ORC raises exceptions
+   * at 'checkLiteralType' of SearchArgumentImpl.java.
+   */
+  private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
+    case ByteType | ShortType | IntegerType | LongType =>
+      value.asInstanceOf[Number].longValue
+    case FloatType | DoubleType =>
+      value.asInstanceOf[Number].doubleValue()
+    case _: DecimalType =>
+      new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal]))
+    case _ => value
+  }
+
+  /**
+   * Build a SearchArgument and return the builder so far.
+   *
+   * @param dataTypeMap a map from the attribute name to its data type.
+   * @param expression the input predicates, which should be fully convertible to SearchArgument.
+   * @param builder the input SearchArgument.Builder.
+   * @return the builder so far.
+   */
+  private def buildSearchArgument(
+      dataTypeMap: Map[String, DataType],
+      expression: Filter,
+      builder: Builder): Builder = {
+    import org.apache.spark.sql.sources._
+
+    expression match {
+      case And(left, right) =>
+        val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd())
+        val rhs = buildSearchArgument(dataTypeMap, right, lhs)
+        rhs.end()
+
+      case Or(left, right) =>
+        val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr())
+        val rhs = buildSearchArgument(dataTypeMap, right, lhs)
+        rhs.end()
+
+      case Not(child) =>
+        buildSearchArgument(dataTypeMap, child, builder.startNot()).end()
+
+      case other =>
+        buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse {
+          throw new SparkException(
+            "The input filter of OrcFilters.buildSearchArgument should be fully convertible.")
+        }
+    }
+  }
+
+  /**
+   * Build a SearchArgument for a leaf predicate and return the builder so far.
+   *
+   * @param dataTypeMap a map from the attribute name to its data type.
+   * @param expression the input filter predicates.
+   * @param builder the input SearchArgument.Builder.
+   * @return the builder so far.
+   */
+  private def buildLeafSearchArgument(
+      dataTypeMap: Map[String, DataType],
+      expression: Filter,
+      builder: Builder): Option[Builder] = {
+    def getType(attribute: String): PredicateLeaf.Type =
+      getPredicateLeafType(dataTypeMap(attribute))
+
+    import org.apache.spark.sql.sources._
+
+    // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
+    // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
+    // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
+    expression match {
+      case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end())
+
+      case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end())
+
+      case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end())
+
+      case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end())
+
+      case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end())
+
+      case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+        Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end())
+
+      case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
+
+      case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
+
+      case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
+        Some(builder.startAnd().in(quotedName, getType(attribute),
+          castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
+
+      case _ => None
+    }
+  }
+}
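To illustrate how the new `OrcFilters.createFilter` is exercised by `GpuOrcPartitionReader` (which passes the full file schema, i.e. data plus partition columns, together with the pushed filters), here is a small, hypothetical standalone sketch. It assumes the shaded nohive `orc-core` jar and this plugin are on the classpath; the object name and the concrete filter values are illustrative only:

```scala
// Hypothetical usage sketch for OrcFilters.createFilter (not part of this change).
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.{EqualTo, GreaterThan, IsNotNull}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object OrcFilterExample {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    // In the reader these arrive as the data source's pushed-down filters.
    val pushed = Seq(IsNotNull("id"), GreaterThan("id", 10), EqualTo("name", "abc"))
    // createFilter combines the filters with AND and builds one ORC SearchArgument,
    // which the reader then applies via readerOpts.searchArgument(sarg, fieldNames).
    OrcFilters.createFilter(schema, pushed).foreach { sarg =>
      println(sarg) // roughly: a set of leaf predicates plus an expression tree combining them
    }
  }
}
```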