From be4c35ee893518967e2a0ff610725177f5af27cf Mon Sep 17 00:00:00 2001
From: Chris Twiner
Date: Thu, 21 Mar 2024 15:05:46 +0100
Subject: [PATCH] #787 - Seq can be stream, fails on dbr, do the same as for arb
---
 .../main/scala/frameless/TypedEncoder.scala   | 10 ++--
 .../scala/frameless/TypedDatasetSuite.scala   | 53 +++++++++++++------
 .../src/test/scala/frameless/package.scala    | 11 +++-
 3 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index e11ec73d..ebbe6e56 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -11,9 +11,6 @@ import org.apache.spark.sql.catalyst.expressions.{
   UnsafeArrayData,
   Literal
 }
-import org.apache.spark.sql.FramelessInternals
-import org.apache.spark.sql.FramelessInternals.UserDefinedType
-import org.apache.spark.sql.{ reflection => ScalaReflection }
 import org.apache.spark.sql.catalyst.util.{
   ArrayBasedMapData,
@@ -528,7 +525,12 @@ object TypedEncoder {
   object CollectionConversion {

     implicit def seqToSeq[Y] = new CollectionConversion[Seq, Seq, Y] {
-      override def convert(c: Seq[Y]): Seq[Y] = c
+      override def convert(c: Seq[Y]): Seq[Y] =
+        c match {
+          // a lazy Stream may be produced, force materialization
+          case _: Stream[Y] @unchecked => c.toVector.toSeq
+          case _ => c
+        }
     }

     implicit def seqToVector[Y] = new CollectionConversion[Seq, Vector, Y] {
diff --git a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
index 8a469783..ef778922 100644
--- a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
+++ b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -2,28 +2,35 @@ package frameless

 import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
 import org.apache.hadoop.fs.local.StreamingFS
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.{ SparkConf, SparkContext }
+import org.apache.spark.sql.{ SQLContext, SparkSession }
 import org.scalactic.anyvals.PosZInt
 import org.scalatest.BeforeAndAfterAll
 import org.scalatestplus.scalacheck.Checkers
 import org.scalacheck.Prop
 import org.scalacheck.Prop._
-import scala.util.{Properties, Try}
+import scala.util.{ Properties, Try }
 import org.scalatest.funsuite.AnyFunSuite

 trait SparkTesting { self: BeforeAndAfterAll =>
-  val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString
+  val appID: String = new java.util.Date().toString + math
+    .floor(math.random * 10e4)
+    .toLong
+    .toString

   /**
    * Allows bare naked to be used instead of winutils for testing / dev
    */
   def registerFS(sparkConf: SparkConf): SparkConf = {
     if (System.getProperty("os.name").startsWith("Windows"))
-      sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName).
- set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName) + sparkConf + .set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName) + .set( + "spark.hadoop.fs.AbstractFileSystem.file.impl", + classOf[StreamingFS].getName + ) else sparkConf } @@ -40,9 +47,9 @@ trait SparkTesting { self: BeforeAndAfterAll => implicit def sc: SparkContext = session.sparkContext implicit def sqlContext: SQLContext = session.sqlContext - def registerOptimizations(sqlContext: SQLContext): Unit = { } + def registerOptimizations(sqlContext: SQLContext): Unit = {} - def addSparkConfigProperties(config: SparkConf): Unit = { } + def addSparkConfigProperties(config: SparkConf): Unit = {} override def beforeAll(): Unit = { assert(s == null) @@ -51,7 +58,7 @@ trait SparkTesting { self: BeforeAndAfterAll => registerOptimizations(sqlContext) } - override def afterAll(): Unit = { + override def afterAll(): Unit = if (shouldCloseSession) { if (s != null) { s.stop() s = null @@ -59,11 +66,16 @@ trait SparkTesting { self: BeforeAndAfterAll => } } +class TypedDatasetSuite + extends AnyFunSuite + with Checkers + with BeforeAndAfterAll + with SparkTesting { -class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting { // Limit size of generated collections and number of checks to avoid OutOfMemoryError implicit override val generatorDrivenConfig: PropertyCheckConfiguration = { - def getPosZInt(name: String, default: PosZInt) = Properties.envOrNone(s"FRAMELESS_GEN_${name}") + def getPosZInt(name: String, default: PosZInt) = Properties + .envOrNone(s"FRAMELESS_GEN_${name}") .flatMap(s => Try(s.toInt).toOption) .flatMap(PosZInt.from) .getOrElse(default) @@ -75,17 +87,24 @@ class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll implicit val sparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob - def approximatelyEqual[A](a: A, b: A)(implicit numeric: Numeric[A]): Prop = { + def approximatelyEqual[A]( + a: A, + b: A + )(implicit + numeric: Numeric[A] + ): Prop = { val da = numeric.toDouble(a) val db = numeric.toDouble(b) - val epsilon = 1E-6 + val epsilon = 1e-6 // Spark has a weird behaviour concerning expressions that should return Inf // Most of the time they return NaN instead, for instance stddev of Seq(-7.827553978923477E227, -5.009124275715786E153) - if((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved + if ((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved else if ( (da - db).abs < epsilon || - (da - db).abs < da.abs / 100) - proved - else falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." + (da - db).abs < da.abs / 100 + ) + proved + else + falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." 
   }
 }
diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala
index 601613c8..8085582a 100644
--- a/dataset/src/test/scala/frameless/package.scala
+++ b/dataset/src/test/scala/frameless/package.scala
@@ -119,7 +119,7 @@

   private var outputDir: String = _

-  /** allow usage on non-build environments */
+  /** allow test usage on non-build environments */
   def setOutputDir(path: String): Unit = {
     outputDir = path
   }
@@ -130,6 +130,15 @@
     else
       "target/test-output"

+  private var shouldClose = true
+
+  /** allow test usage on non-build environments */
+  def setShouldCloseSession(shouldClose: Boolean): Unit = {
+    this.shouldClose = shouldClose
+  }
+
+  lazy val shouldCloseSession = shouldClose
+
   /**
    * Will dive down causes until either the cause is true or there are no more causes
    * @param t
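
Note on the TypedEncoder change: the decoded Seq can arrive as a Stream, whose
tail is evaluated lazily (the subject line notes this fails on dbr); routing it
through toVector, as the patched seqToSeq does, makes every element strict up
front. A minimal plain-Scala sketch of that guard, outside Spark entirely; the
names below are illustrative and not part of the patch:

  // A Seq that is actually a lazy Stream, as the decoder can produce
  val lazySeq: Seq[Int] = Stream.from(1).take(3)

  // Mirrors the patched convert: materialize eagerly when a Stream shows up
  def forceIfStream[Y](c: Seq[Y]): Seq[Y] = c match {
    case _: Stream[Y] @unchecked => c.toVector.toSeq // strict copy
    case _ => c
  }

  assert(forceIfStream(lazySeq).isInstanceOf[Vector[_]]) // fully evaluated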
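Note on the test hooks: setShouldCloseSession complements setOutputDir, both
marked "allow test usage on non-build environments", and afterAll() now only
stops the session when shouldCloseSession holds. A hedged usage sketch for a
harness that owns the SparkSession; the directory value is illustrative:

  // Both members live on the frameless test package object; call before the
  // suite starts so afterAll() leaves the externally managed session running.
  frameless.setShouldCloseSession(false)
  frameless.setOutputDir("/tmp/frameless-test-output") // any writable scratch dir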