typelevel#787 - remove all sql package private code
chris-twiner committed Mar 7, 2024
1 parent 0f9b7cf commit 059a8e6
Showing 12 changed files with 79 additions and 55 deletions.
build.sbt (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ val shimVersion = "0.0.1-SNAPSHOT"
val Scala212 = "2.12.19"
val Scala213 = "2.13.13"

-//resolvers in Global += Resolver.mavenLocal
+resolvers in Global += Resolver.mavenLocal
resolvers in Global += MavenRepository(
"sonatype-s01-snapshots",
Resolver.SonatypeS01RepositoryRoot + "/snapshots"
@@ -1,20 +1,15 @@
-package org.apache.spark.sql
+package frameless

-import org.apache.spark.sql.catalyst.expressions.codegen._
-import com.sparkutils.shim.expressions.{
-  Alias2 => Alias,
-  CreateStruct1 => CreateStruct
-}
-import org.apache.spark.sql.catalyst.expressions.{
-  Expression,
-  NamedExpression,
-  NonSQLExpression
-}
+import com.sparkutils.shim.expressions.{Alias2 => Alias, CreateStruct1 => CreateStruct}
+import org.apache.spark.sql.shim.{utils => shimUtils}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Project }
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, NonSQLExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql._

import scala.reflect.ClassTag

object FramelessInternals {
@@ -36,7 +31,7 @@ object FramelessInternals

def expr(column: Column): Expression = column.expr

-def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan
+def logicalPlan(ds: Dataset[_]): LogicalPlan = shimUtils.logicalPlan(ds)

def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
ds.sparkSession.sessionState.executePlan(plan)
@@ -68,7 +63,7 @@
new Dataset(sqlContext, plan, encoder)

def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
-Dataset.ofRows(sparkSession, logicalPlan)
+shimUtils.ofRows(sparkSession, logicalPlan)

// because org.apache.spark.sql.types.UserDefinedType is private[spark]
type UserDefinedType[A >: Null] =
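
For context, this is the pattern the commit applies throughout: a minimal sketch (the shimUtils calls mirror the diff above; the wrapper object name is hypothetical).

```scala
// Code living in `package org.apache.spark.sql` could reach private[sql]
// members such as ds.logicalPlan and Dataset.ofRows directly. After the
// move to `package frameless`, the same operations go through the
// sparkutils shim instead.
import org.apache.spark.sql.{ DataFrame, Dataset, SparkSession }
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.shim.{ utils => shimUtils }

object ShimDelegationSketch {

  // was: ds.logicalPlan (private[sql])
  def logicalPlan(ds: Dataset[_]): LogicalPlan = shimUtils.logicalPlan(ds)

  // was: Dataset.ofRows(sparkSession, plan) (private[sql])
  def ofRows(sparkSession: SparkSession, plan: LogicalPlan): DataFrame =
    shimUtils.ofRows(sparkSession, plan)
}
```
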
dataset/src/main/scala/frameless/RecordEncoder.scala (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@ import com.sparkutils.shim.expressions.{
WrapOption2 => WrapOption
}
import com.sparkutils.shim.{ deriveUnitLiteral, ifIsNull }
-import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.catalyst.expressions.{ Expression, Literal }
import org.apache.spark.sql.shim.{
Invoke5 => Invoke,
dataset/src/main/scala/frameless/TypedColumn.scala (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ import org.apache.spark.sql.catalyst.expressions.{
Literal
} // 787 - Spark 4 source code compat
import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.{ Column, FramelessInternals }
+import org.apache.spark.sql.Column

import shapeless._
import shapeless.ops.record.Selector
dataset/src/main/scala/frameless/TypedDataset.scala (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ import java.util
import frameless.functions.CatalystExplodableCollection
import frameless.ops._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, FramelessInternals, SparkSession}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.catalyst.plans.Inner
dataset/src/main/scala/frameless/TypedEncoder.scala (6 changes: 2 additions & 4 deletions)
@@ -5,10 +5,7 @@ import java.util.Date
import java.time.{ Duration, Instant, LocalDate, Period }
import java.sql.Timestamp
import scala.reflect.ClassTag

-import org.apache.spark.sql.FramelessInternals
-import org.apache.spark.sql.FramelessInternals.UserDefinedType
-import org.apache.spark.sql.{ reflection => ScalaReflection }
+import FramelessInternals.UserDefinedType
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnsafeArrayData, Literal }
import org.apache.spark.sql.catalyst.util.{
ArrayBasedMapData,
@@ -26,6 +23,7 @@ import com.sparkutils.shim.expressions.{
MapObjects5 => MapObjects,
ExternalMapToCatalyst7 => ExternalMapToCatalyst
}
+import frameless.{reflection => ScalaReflection}
import org.apache.spark.sql.shim.{
StaticInvoke4 => StaticInvoke,
NewInstance4 => NewInstance,
@@ -1,7 +1,7 @@
package frameless
package functions

-import org.apache.spark.sql.FramelessInternals.expr
+import FramelessInternals.expr
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.{ functions => sparkFunctions }
import frameless.syntax._
dataset/src/main/scala/frameless/functions/package.scala (3 changes: 1 addition & 2 deletions)
@@ -1,13 +1,12 @@
package frameless

+import frameless.{reflection => ScalaReflection}
import scala.reflect.ClassTag

import shapeless._
import shapeless.labelled.FieldType
import shapeless.ops.hlist.IsHCons
import shapeless.ops.record.{ Keys, Values }

-import org.apache.spark.sql.{ reflection => ScalaReflection }
import org.apache.spark.sql.catalyst.expressions.Literal

package object functions extends Udf with UnaryFunctions {
dataset/src/main/scala/frameless/ops/GroupByOps.scala (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.{
Column,
Dataset,
-FramelessInternals,
RelationalGroupedDataset
}
import shapeless._
@@ -19,6 +18,7 @@ import shapeless.ops.hlist.{
Tupler
}
import com.sparkutils.shim.expressions.{ MapGroups4 => MapGroups }
+import frameless.FramelessInternals

class GroupedByManyOps[T, TK <: HList, K <: HList, KT](
self: TypedDataset[T],
@@ -1,26 +1,6 @@
-package org.apache.spark.sql
+package frameless

-import org.apache.spark.sql.catalyst.ScalaReflection.{
-  cleanUpReflectionObjects,
-  getClassFromType,
-  localTypeOf
-}
-import org.apache.spark.sql.types.{
-  BinaryType,
-  BooleanType,
-  ByteType,
-  CalendarIntervalType,
-  DataType,
-  Decimal,
-  DecimalType,
-  DoubleType,
-  FloatType,
-  IntegerType,
-  LongType,
-  NullType,
-  ObjectType,
-  ShortType
-}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

/**
@@ -45,6 +25,59 @@ package object reflection {

import universe._

+  // Since we are creating a runtime mirror using the class loader of current thread,
+  // we need to use def at here. So, every time we call mirror, it is using the
+  // class loader of the current thread.
+  def mirror: universe.Mirror = {
+    universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
+  }
+
+  /**
+   * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
+   * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
+   * `scala.reflect.runtime.JavaUniverse.undoLog`.
+   *
+   * @see https://github.com/scala/bug/issues/8302
+   */
+  def cleanUpReflectionObjects[T](func: => T): T = {
+    universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
+  }
+
+  /**
+   * Return the Scala Type for `T` in the current classloader mirror.
+   *
+   * Use this method instead of the convenience method `universe.typeOf`, which
+   * assumes that all types can be found in the classloader that loaded scala-reflect classes.
+   * That's not necessarily the case when running using Eclipse launchers or even
+   * Sbt console or test (without `fork := true`).
+   *
+   * @see SPARK-5281
+   */
+  def localTypeOf[T: TypeTag]: `Type` = {
+    val tag = implicitly[TypeTag[T]]
+    tag.in(mirror).tpe.dealias
+  }
+
+  /*
+   * Retrieves the runtime class corresponding to the provided type.
+   */
+  def getClassFromType(tpe: Type): Class[_] =
+    mirror.runtimeClass(erasure(tpe).dealias.typeSymbol.asClass)
+
+  private def erasure(tpe: Type): Type = {
+    // For user-defined AnyVal classes, we should not erasure it. Otherwise, it will
+    // resolve to underlying type which wrapped by this class, e.g erasure
+    // `case class Foo(i: Int) extends AnyVal` will return type `Int` instead of `Foo`.
+    // But, for other types, we do need to erasure it. For example, we need to erasure
+    // `scala.Any` to `java.lang.Object` in order to load it from Java ClassLoader.
+    // Please see SPARK-17368 & SPARK-31190 for more details.
+    if (isSubtype(tpe, localTypeOf[AnyVal]) && !tpe.toString.startsWith("scala")) {
+      tpe
+    } else {
+      tpe.erasure
+    }
+  }
+
/**
* Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
* to a native type, an ObjectType is returned. Special handling is also used for Arrays including
@@ -62,7 +95,7 @@
*
* See https://github.com/scala/bug/issues/10766
*/
-private[sql] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+private def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
ScalaSubtypeLock.synchronized {
tpe1 <:< tpe2
}
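
A short usage sketch of the helpers copied in above, assuming the package object is exposed as frameless.reflection (which the `import frameless.{reflection => ScalaReflection}` lines elsewhere in this commit suggest); the object name is illustrative.

```scala
import frameless.{ reflection => ScalaReflection }

object ReflectionSketch {

  // Resolved against the current thread's context classloader mirror, so
  // this also works under sbt console or Eclipse launchers (SPARK-5281).
  val tpe = ScalaReflection.localTypeOf[Map[String, Int]]

  // Runtime class recovery; user-defined AnyVal wrappers are deliberately
  // kept rather than erased to their underlying type (SPARK-17368, SPARK-31190).
  val cls: Class[_] = ScalaReflection.getClassFromType(tpe)

  // Wrap `<:<` checks to avoid leaking entries into scala-reflect's undoLog
  // (https://github.com/scala/bug/issues/8302).
  val isMap: Boolean = ScalaReflection.cleanUpReflectionObjects {
    tpe <:< ScalaReflection.localTypeOf[Map[_, _]]
  }
}
```
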
dataset/src/test/scala/frameless/UdtEncodedClass.scala (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@ package frameless
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.FramelessInternals.UserDefinedType
+import FramelessInternals.UserDefinedType

@SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt])
class UdtEncodedClass(val a: Int, val b: Array[Double]) {
ml/src/main/scala/frameless/ml/package.scala (8 changes: 4 additions & 4 deletions)
@@ -1,13 +1,13 @@
package frameless

-import org.apache.spark.sql.FramelessInternals.UserDefinedType
-import org.apache.spark.ml.FramelessInternals
+import FramelessInternals.UserDefinedType
+import org.apache.spark.ml.{FramelessInternals => MLFramelessInternals}
import org.apache.spark.ml.linalg.{Matrix, Vector}

package object ml {

-implicit val mlVectorUdt: UserDefinedType[Vector] = FramelessInternals.vectorUdt
+implicit val mlVectorUdt: UserDefinedType[Vector] = MLFramelessInternals.vectorUdt

-implicit val mlMatrixUdt: UserDefinedType[Matrix] = FramelessInternals.matrixUdt
+implicit val mlMatrixUdt: UserDefinedType[Matrix] = MLFramelessInternals.matrixUdt

}
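
With two FramelessInternals objects now in play (frameless.FramelessInternals and Spark ML's org.apache.spark.ml.FramelessInternals), a rename on import keeps both usable in one scope. A minimal sketch (object name illustrative):

```scala
import frameless.FramelessInternals.UserDefinedType
import org.apache.spark.ml.{ FramelessInternals => MLFramelessInternals }
import org.apache.spark.ml.linalg.Vector

object MlUdtSketch {

  // The UDT instance comes from the ML-side internals, while the
  // UserDefinedType alias comes from frameless.FramelessInternals,
  // exactly as in the package object above.
  val vectorUdt: UserDefinedType[Vector] = MLFramelessInternals.vectorUdt
}
```
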
