From 51b9d719faf13760524aee56febb080ff9da1b04 Mon Sep 17 00:00:00 2001 From: peng Date: Fri, 24 Nov 2023 16:51:58 -0500 Subject: [PATCH] Add a test case for the new TypedRow encoder that implements the proposal --- .../main/scala/frameless/RecordEncoder.scala | 94 ++++++++++++++----- .../scala/frameless/RecordEncoderStage1.scala | 49 ++++++++++ .../main/scala/frameless/TypedEncoder.scala | 12 ++- .../src/main/scala/frameless/TypedRow.scala | 45 +++++++++ .../test/scala/frameless/InjectionTests.scala | 2 +- .../scala/frameless/RecordEncoderTests.scala | 46 +++++---- .../frameless/RefinedFieldEncoderTests.scala | 4 +- 7 files changed, 209 insertions(+), 43 deletions(-) create mode 100644 dataset/src/main/scala/frameless/RecordEncoderStage1.scala create mode 100644 dataset/src/main/scala/frameless/TypedRow.scala diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 7427d9de0..3349c72a0 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -119,20 +119,19 @@ object DropUnitValues { } } -class RecordEncoder[F, G <: HList, H <: HList] +abstract class RecordEncoder[F, G <: HList, H <: HList] (implicit - i0: LabelledGeneric.Aux[F, G], - i1: DropUnitValues.Aux[G, H], - i2: IsHCons[H], - fields: Lazy[RecordEncoderFields[H]], - newInstanceExprs: Lazy[NewInstanceExprs[G]], + stage1: RecordEncoderStage1[G, H], classTag: ClassTag[F] ) extends TypedEncoder[F] { + + import stage1._ + def nullable: Boolean = false - def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] + lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F] - def catalystRepr: DataType = { + lazy val catalystRepr: DataType = { val structFields = fields.value.value.map { field => StructField( name = field.name, @@ -145,34 +144,35 @@ class RecordEncoder[F, G <: HList, H <: HList] StructType(structFields) } + } + +object RecordEncoder { + + case class 
ForGeneric[F, G <: HList, H <: HList]( + )(implicit + stage1: RecordEncoderStage1[G, H], + classTag: ClassTag[F]) + extends RecordEncoder[F, G, H] { + + import stage1._ + def toCatalyst(path: Expression): Expression = { - val nameExprs = fields.value.value.map { field => - Literal(field.name) - } val valueExprs = fields.value.value.map { field => val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) field.encoder.toCatalyst(fieldPath) } - // the way exprs are encoded in CreateNamedStruct - val exprs = nameExprs.zip(valueExprs).flatMap { - case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil - } + val createExpr = stage1.cellsToCatalyst(valueExprs) - val createExpr = CreateNamedStruct(exprs) val nullExpr = Literal.create(null, createExpr.dataType) If(IsNull(path), nullExpr, createExpr) } def fromCatalyst(path: Expression): Expression = { - val exprs = fields.value.value.map { field => - field.encoder.fromCatalyst( - GetStructField(path, field.ordinal, Some(field.name))) - } - val newArgs = newInstanceExprs.value.from(exprs) + val newArgs = stage1.fromCatalystToCells(path) val newExpr = NewInstance( classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) @@ -180,6 +180,58 @@ class RecordEncoder[F, G <: HList, H <: HList] If(IsNull(path), nullExpr, newExpr) } + } + + case class ForTypedRow[G <: HList, H <: HList]( + )(implicit + stage1: RecordEncoderStage1[G, H], + classTag: ClassTag[TypedRow[G]]) + extends RecordEncoder[TypedRow[G], G, H] { + + import stage1._ + + private final val _apply = "apply" + private final val _fromInternalRow = "fromInternalRow" + + def toCatalyst(path: Expression): Expression = { + + val valueExprs = fields.value.value.zipWithIndex.map { + case (field, i) => + val fieldPath = Invoke( + path, + _apply, + field.encoder.jvmRepr, + Seq(Literal.create(i, IntegerType)) + ) + field.encoder.toCatalyst(fieldPath) + } + + val createExpr = stage1.cellsToCatalyst(valueExprs) + + val nullExpr = Literal.create(null, 
createExpr.dataType) + + If(IsNull(path), nullExpr, createExpr) + } + + def fromCatalyst(path: Expression): Expression = { + + val newArgs = stage1.fromCatalystToCells(path) + val aggregated = CreateStruct(newArgs) + + val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType)) + + val newExpr = Invoke( + Literal.fromObject(partial), + _fromInternalRow, + TypedRow.catalystType, + Seq(aggregated) + ) + + val nullExpr = Literal.create(null, jvmRepr) + + If(IsNull(path), nullExpr, newExpr) + } + } } final class RecordFieldEncoder[T]( diff --git a/dataset/src/main/scala/frameless/RecordEncoderStage1.scala b/dataset/src/main/scala/frameless/RecordEncoderStage1.scala new file mode 100644 index 000000000..b7cecf380 --- /dev/null +++ b/dataset/src/main/scala/frameless/RecordEncoderStage1.scala @@ -0,0 +1,49 @@ +package frameless + +import org.apache.spark.sql.catalyst.expressions.{ + CreateNamedStruct, + Expression, + GetStructField, + Literal +} +import shapeless.{ HList, Lazy } + +case class RecordEncoderStage1[G <: HList, H <: HList]( + )(implicit +// i1: DropUnitValues.Aux[G, H], +// i2: IsHCons[H], + val fields: Lazy[RecordEncoderFields[H]], + val newInstanceExprs: Lazy[NewInstanceExprs[G]]) { + + def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = { + val nameExprs = fields.value.value.map { field => Literal(field.name) } + + // the way exprs are encoded in CreateNamedStruct + val exprs = nameExprs.zip(valueExprs).flatMap { + case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil + } + + val createExpr = CreateNamedStruct(exprs) + createExpr + } + + def fromCatalystToCells(path: Expression): Seq[Expression] = { + val exprs = fields.value.value.map { field => + field.encoder.fromCatalyst( + GetStructField(path, field.ordinal, Some(field.name)) + ) + } + + val newArgs = newInstanceExprs.value.from(exprs) + newArgs + } +} + +object RecordEncoderStage1 { + + implicit def usingDerivation[G <: HList, H <: HList]( + implicit + i3: 
Lazy[RecordEncoderFields[H]], + i4: Lazy[NewInstanceExprs[G]] + ): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]() +} diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index b42b026ee..2877dc7d9 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -727,7 +727,7 @@ object TypedEncoder { } /** Encodes things as records if there is no Injection defined */ - implicit def usingDerivation[F, G <: HList, H <: HList]( + implicit def deriveForGeneric[F, G <: HList, H <: HList]( implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], @@ -735,7 +735,15 @@ object TypedEncoder { i3: Lazy[RecordEncoderFields[H]], i4: Lazy[NewInstanceExprs[G]], i5: ClassTag[F] - ): TypedEncoder[F] = new RecordEncoder[F, G, H] + ): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]() + + implicit def deriveForTypedRow[G <: HList, H <: HList]( + implicit + i1: DropUnitValues.Aux[G, H], + i2: IsHCons[H], + i3: Lazy[RecordEncoderFields[H]], + i4: Lazy[NewInstanceExprs[G]] + ): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]() /** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */ implicit def usingUserDefinedType[ diff --git a/dataset/src/main/scala/frameless/TypedRow.scala b/dataset/src/main/scala/frameless/TypedRow.scala new file mode 100644 index 000000000..c7abae774 --- /dev/null +++ b/dataset/src/main/scala/frameless/TypedRow.scala @@ -0,0 +1,45 @@ +package frameless + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{ DataType, ObjectType } +import shapeless.HList + +case class TypedRow[T <: HList](row: Row) { + + def apply(i: Int): Any = row.apply(i) +} + +object TypedRow { + + def apply(values: Any*): TypedRow[HList] = { + + val row = Row.fromSeq(values) + TypedRow(row) + } + + case class 
WithCatalystTypes(schema: Seq[DataType]) { + + def fromInternalRow(row: InternalRow): TypedRow[HList] = { + val data = row.toSeq(schema).toArray + + apply(data: _*) + } + + } + + object WithCatalystTypes {} + + def fromHList[T <: HList]( + hlist: T + ): TypedRow[T] = { + + val cells = hlist.runtimeList + + val row = Row.fromSeq(cells) + TypedRow(row) + } + + lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]]) + +} diff --git a/dataset/src/test/scala/frameless/InjectionTests.scala b/dataset/src/test/scala/frameless/InjectionTests.scala index c17a52bd7..9cd762a78 100644 --- a/dataset/src/test/scala/frameless/InjectionTests.scala +++ b/dataset/src/test/scala/frameless/InjectionTests.scala @@ -180,7 +180,7 @@ class InjectionTests extends TypedDatasetSuite { } test("Resolve ambiguity by importing usingDerivation") { - import TypedEncoder.usingDerivation + import TypedEncoder.deriveForGeneric assert(implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]]) check(forAll(prop[Person] _)) } diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 98274cf01..101a486bf 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -1,23 +1,12 @@ package frameless -import org.apache.spark.sql.{Row, functions => F} -import org.apache.spark.sql.types.{ - ArrayType, - BinaryType, - DecimalType, - IntegerType, - LongType, - MapType, - ObjectType, - StringType, - StructField, - StructType -} - -import shapeless.{HList, LabelledGeneric} -import shapeless.test.illTyped - +import frameless.RecordEncoderTests.{ A, B, E } +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{ Row, functions => F } import org.scalatest.matchers.should.Matchers +import shapeless.record.Record +import shapeless.test.illTyped +import shapeless.{ HList, LabelledGeneric } final class RecordEncoderTests extends 
TypedDatasetSuite with Matchers { test("Unable to encode products made from units only") { @@ -87,6 +76,26 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { ds.collect.head shouldBe obj } + test("TypedRow") { + + val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc") + val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1) + + val rdd = sc.parallelize(Seq(r2)) + val ds = + session.createDataset(rdd)( + TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]] + ) + + ds.schema.treeString shouldBe + """root + | |-- x: integer (nullable = true) + | |-- y: string (nullable = true) + |""".stripMargin + + ds.collect.head shouldBe r2 + } + test("Scalar value class") { import RecordEncoderTests._ @@ -548,6 +557,9 @@ object RecordEncoderTests { case class D(m: Map[String, Int]) case class E(b: Set[B]) + val RR = Record.`'x -> Int, 'y -> String` + type RR = RR.T + final class Subject(val name: String) extends AnyVal with Serializable final class Grade(val value: BigDecimal) extends AnyVal with Serializable diff --git a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala index 5476284ea..38c2781f0 100644 --- a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala +++ b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala @@ -114,7 +114,7 @@ object RefinedTypesTests { import frameless.refined._ // implicit instances for refined - implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation + implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric - implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation + implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric }