diff --git a/build.sbt b/build.sbt index 371226cb4..5e529c631 100644 --- a/build.sbt +++ b/build.sbt @@ -17,7 +17,7 @@ val shimVersion = "0.0.1-SNAPSHOT" val Scala212 = "2.12.19" val Scala213 = "2.13.13" -//resolvers in Global += Resolver.mavenLocal +resolvers in Global += Resolver.mavenLocal resolvers in Global += MavenRepository( "sonatype-s01-snapshots", Resolver.SonatypeS01RepositoryRoot + "/snapshots" diff --git a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala b/dataset/src/main/scala/frameless/FramelessInternals.scala similarity index 86% rename from dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala rename to dataset/src/main/scala/frameless/FramelessInternals.scala index 6eb1d1baf..78684e7b1 100644 --- a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala +++ b/dataset/src/main/scala/frameless/FramelessInternals.scala @@ -1,20 +1,15 @@ -package org.apache.spark.sql +package frameless -import org.apache.spark.sql.catalyst.expressions.codegen._ -import com.sparkutils.shim.expressions.{ - Alias2 => Alias, - CreateStruct1 => CreateStruct -} -import org.apache.spark.sql.catalyst.expressions.{ - Expression, - NamedExpression, - NonSQLExpression -} +import com.sparkutils.shim.expressions.{Alias2 => Alias, CreateStruct1 => CreateStruct} +import org.apache.spark.sql.shim.{utils => shimUtils} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Project } +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql._ + import scala.reflect.ClassTag object FramelessInternals { @@ -36,7 +31,7 @@ object FramelessInternals { def expr(column: Column): Expression = column.expr - def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan + def logicalPlan(ds: Dataset[_]): LogicalPlan = shimUtils.logicalPlan(ds) def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution = ds.sparkSession.sessionState.executePlan(plan) @@ -68,7 +63,7 @@ object FramelessInternals { new Dataset(sqlContext, plan, encoder) def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = - Dataset.ofRows(sparkSession, logicalPlan) + shimUtils.ofRows(sparkSession, logicalPlan) // because org.apache.spark.sql.types.UserDefinedType is private[spark] type UserDefinedType[A >: Null] = diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index a781902b0..574ce4272 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -7,7 +7,6 @@ import com.sparkutils.shim.expressions.{ WrapOption2 => WrapOption } import com.sparkutils.shim.{ deriveUnitLiteral, ifIsNull } -import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.catalyst.expressions.{ Expression, Literal } import org.apache.spark.sql.shim.{ Invoke5 => Invoke, diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 4cd3fcc64..5a31a8529 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -8,7 +8,7 @@ import 
org.apache.spark.sql.catalyst.expressions.{ Literal } // 787 - Spark 4 source code compat import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{ Column, FramelessInternals } +import org.apache.spark.sql.Column import shapeless._ import shapeless.ops.record.Selector diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index add2170b2..8a75c009e 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -4,7 +4,7 @@ import java.util import frameless.functions.CatalystExplodableCollection import frameless.ops._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, FramelessInternals, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.catalyst.plans.Inner diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index 6f450c7f0..3a10f9781 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -5,10 +5,7 @@ import java.util.Date import java.time.{ Duration, Instant, LocalDate, Period } import java.sql.Timestamp import scala.reflect.ClassTag - -import org.apache.spark.sql.FramelessInternals -import org.apache.spark.sql.FramelessInternals.UserDefinedType -import org.apache.spark.sql.{ reflection => ScalaReflection } +import FramelessInternals.UserDefinedType import org.apache.spark.sql.catalyst.expressions.{ Expression, UnsafeArrayData, Literal } import org.apache.spark.sql.catalyst.util.{ ArrayBasedMapData, @@ -26,6 +23,7 @@ import com.sparkutils.shim.expressions.{ MapObjects5 => MapObjects, ExternalMapToCatalyst7 => ExternalMapToCatalyst } +import frameless.{reflection => ScalaReflection} import org.apache.spark.sql.shim.{ StaticInvoke4 => StaticInvoke, NewInstance4 => NewInstance, diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index 1263afd68..ad137a4d6 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -1,7 +1,7 @@ package frameless package functions -import org.apache.spark.sql.FramelessInternals.expr +import FramelessInternals.expr import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.{ functions => sparkFunctions } import frameless.syntax._ diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 1a57101e0..391852dce 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -1,13 +1,12 @@ package frameless +import frameless.{reflection => ScalaReflection} import scala.reflect.ClassTag import shapeless._ import shapeless.labelled.FieldType import shapeless.ops.hlist.IsHCons import shapeless.ops.record.{ Keys, Values } - -import org.apache.spark.sql.{ reflection => ScalaReflection } import org.apache.spark.sql.catalyst.expressions.Literal package object functions extends Udf with UnaryFunctions { diff --git a/dataset/src/main/scala/frameless/ops/GroupByOps.scala 
b/dataset/src/main/scala/frameless/ops/GroupByOps.scala
index 7cda753e5..1fbb314e5 100644
--- a/dataset/src/main/scala/frameless/ops/GroupByOps.scala
+++ b/dataset/src/main/scala/frameless/ops/GroupByOps.scala
@@ -6,7 +6,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.{
   Column,
   Dataset,
-  FramelessInternals,
   RelationalGroupedDataset
 }
 import shapeless._
@@ -19,6 +18,7 @@ import shapeless.ops.hlist.{
   Tupler
 }
 import com.sparkutils.shim.expressions.{ MapGroups4 => MapGroups }
+import frameless.FramelessInternals
 
 class GroupedByManyOps[T, TK <: HList, K <: HList, KT](
     self: TypedDataset[T],
diff --git a/dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala b/dataset/src/main/scala/frameless/reflection/package.scala
similarity index 53%
rename from dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala
rename to dataset/src/main/scala/frameless/reflection/package.scala
index 07090a8db..aa4551225 100644
--- a/dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala
+++ b/dataset/src/main/scala/frameless/reflection/package.scala
@@ -1,26 +1,6 @@
-package org.apache.spark.sql
+package frameless
 
-import org.apache.spark.sql.catalyst.ScalaReflection.{
-  cleanUpReflectionObjects,
-  getClassFromType,
-  localTypeOf
-}
-import org.apache.spark.sql.types.{
-  BinaryType,
-  BooleanType,
-  ByteType,
-  CalendarIntervalType,
-  DataType,
-  Decimal,
-  DecimalType,
-  DoubleType,
-  FloatType,
-  IntegerType,
-  LongType,
-  NullType,
-  ObjectType,
-  ShortType
-}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
 /**
@@ -45,6 +25,59 @@ package object reflection {
 
   import universe._
 
+  // Since we are creating a runtime mirror using the class loader of the current
+  // thread, this needs to be a def: every call to `mirror` then uses the class
+  // loader of the thread that invokes it.
+  def mirror: universe.Mirror = {
+    universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
+  }
+
+  /**
+   * Any code calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
+   * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
+   * `scala.reflect.runtime.JavaUniverse.undoLog`.
+   *
+   * @see https://github.com/scala/bug/issues/8302
+   */
+  def cleanUpReflectionObjects[T](func: => T): T = {
+    universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
+  }
+
+  /**
+   * Return the Scala Type for `T` in the current classloader mirror.
+   *
+   * Use this method instead of the convenience method `universe.typeOf`, which
+   * assumes that all types can be found in the classloader that loaded scala-reflect classes.
+   * That's not necessarily the case when running using Eclipse launchers or even
+   * Sbt console or test (without `fork := true`).
+   *
+   * @see SPARK-5281
+   */
+  def localTypeOf[T: TypeTag]: `Type` = {
+    val tag = implicitly[TypeTag[T]]
+    tag.in(mirror).tpe.dealias
+  }
+
+  /**
+   * Retrieves the runtime class corresponding to the provided type.
+   */
+  def getClassFromType(tpe: Type): Class[_] =
+    mirror.runtimeClass(erasure(tpe).dealias.typeSymbol.asClass)
+
+  private def erasure(tpe: Type): Type = {
+    // For user-defined AnyVal classes, we should not erase the type. Otherwise, it will
+    // resolve to the underlying type wrapped by the class, e.g. erasing
+    // `case class Foo(i: Int) extends AnyVal` will return type `Int` instead of `Foo`.
+    // But for other types we do need to erase. For example, we need to erase
+    // `scala.Any` to `java.lang.Object` in order to load it from the Java ClassLoader.
+    // Please see SPARK-17368 & SPARK-31190 for more details.
+    if (isSubtype(tpe, localTypeOf[AnyVal]) && !tpe.toString.startsWith("scala")) {
+      tpe
+    } else {
+      tpe.erasure
+    }
+  }
+
   /**
    * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
    * to a native type, an ObjectType is returned. Special handling is also used for Arrays including
@@ -62,7 +95,7 @@ package object reflection {
    *
    * See https://github.com/scala/bug/issues/10766
    */
-  private[sql] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+  private def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
     ScalaSubtypeLock.synchronized {
       tpe1 <:< tpe2
     }
diff --git a/dataset/src/test/scala/frameless/UdtEncodedClass.scala b/dataset/src/test/scala/frameless/UdtEncodedClass.scala
index 4e5c2c6d9..1c000c58c 100644
--- a/dataset/src/test/scala/frameless/UdtEncodedClass.scala
+++ b/dataset/src/test/scala/frameless/UdtEncodedClass.scala
@@ -3,7 +3,7 @@ package frameless
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.FramelessInternals.UserDefinedType
+import FramelessInternals.UserDefinedType
 
 @SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt])
 class UdtEncodedClass(val a: Int, val b: Array[Double]) {
diff --git a/ml/src/main/scala/frameless/ml/package.scala b/ml/src/main/scala/frameless/ml/package.scala
index d1c306158..1ce56980b 100644
--- a/ml/src/main/scala/frameless/ml/package.scala
+++ b/ml/src/main/scala/frameless/ml/package.scala
@@ -1,13 +1,13 @@
 package frameless
 
-import org.apache.spark.sql.FramelessInternals.UserDefinedType
-import org.apache.spark.ml.FramelessInternals
+import FramelessInternals.UserDefinedType
+import org.apache.spark.ml.{FramelessInternals => MLFramelessInternals}
 import org.apache.spark.ml.linalg.{Matrix, Vector}
 
 package object ml {
 
-  implicit val mlVectorUdt: UserDefinedType[Vector] = FramelessInternals.vectorUdt
+  implicit val mlVectorUdt: UserDefinedType[Vector] = MLFramelessInternals.vectorUdt
 
-  implicit val mlMatrixUdt: UserDefinedType[Matrix] = FramelessInternals.matrixUdt
+  implicit val mlMatrixUdt: UserDefinedType[Matrix] = MLFramelessInternals.matrixUdt
 }
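
Note on the `FramelessInternals` move: the object now lives in the `frameless` package, and the two members that previously relied on `private[sql]` access (`logicalPlan` and `ofRows`) are routed through `org.apache.spark.sql.shim.utils`. Below is a minimal sketch of what a call site looks like after the move; the object and method names (`InternalsExample`, `roundTrip`) are illustrative, while the `FramelessInternals` signatures are taken from the diff above.

```scala
import frameless.FramelessInternals // was org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{ DataFrame, Dataset, SparkSession }

object InternalsExample {
  // Both helpers now delegate to org.apache.spark.sql.shim.utils rather than
  // reaching into Dataset internals that are private[sql] inside Spark.
  def roundTrip(spark: SparkSession, ds: Dataset[Long]): DataFrame = {
    val plan: LogicalPlan = FramelessInternals.logicalPlan(ds)
    FramelessInternals.ofRows(spark, plan)
  }
}
```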
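
The `frameless.reflection` package object inlines the helpers that were previously borrowed from `org.apache.spark.sql.catalyst.ScalaReflection`. A small usage sketch, assuming the package object compiles as shown in the hunk above; `ReflectionExample` is an illustrative name, and `dataTypeFor` is elided in the hunk, so it is not exercised here.

```scala
import frameless.reflection._

object ReflectionExample extends App {
  // Type lookup goes through the current thread's context class loader (SPARK-5281),
  // so it also behaves under sbt console or test without `fork := true`.
  val intTpe = localTypeOf[Int]
  val stringClass = getClassFromType(localTypeOf[String]) // classOf[String]

  // Per the scaladoc above, direct uses of `<:<` should be wrapped in
  // cleanUpReflectionObjects so the reflection undo log does not leak
  // (scala/bug#8302); isSubtype itself stays private after this change.
  val isAnyVal = cleanUpReflectionObjects {
    intTpe <:< localTypeOf[AnyVal]
  }

  println((intTpe, stringClass, isAnyVal)) // (Int, class java.lang.String, true)
}
```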