Commit

typelevel#787 - attempt to solve all but covar_pop and kurtosis
chris-twiner committed Apr 11, 2024
1 parent 986891a commit 66b31e9
Showing 4 changed files with 208 additions and 70 deletions.
@@ -20,19 +20,35 @@ object DoubleBehaviourUtils {
val nanNullHandler: Any => Option[BigDecimal] = {
case null => None
case d: Double =>
nanHandler(d).map { d =>
if (d == Double.NegativeInfinity || d == Double.PositiveInfinity)
BigDecimal("1000000.000000") * (if (d == Double.PositiveInfinity) 1
else -1)
else
BigDecimal(d).setScale(
6,
if (d > 0)
BigDecimal.RoundingMode.FLOOR
else
BigDecimal.RoundingMode.CEILING
)
}
nanHandler(d).map(truncate)
case _ => ???
}

/** ensure different serializations are 'comparable' */
def truncate(d: Double): BigDecimal =
if (d == Double.NegativeInfinity || d == Double.PositiveInfinity)
BigDecimal("1000000.000000") * (if (d == Double.PositiveInfinity) 1
else -1)
else
BigDecimal(d).setScale(
6,
if (d > 0)
BigDecimal.RoundingMode.FLOOR
else
BigDecimal.RoundingMode.CEILING
)
}

/** drop in conversion for doubles to handle serialization on cluster */
trait ToDecimal[A] {
def truncate(a: A): Option[BigDecimal]
}

object ToDecimal {

implicit val doubleToDecimal: ToDecimal[Double] = new ToDecimal[Double] {

override def truncate(a: Double): Option[BigDecimal] =
DoubleBehaviourUtils.nanNullHandler(a)
}
}
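
For orientation, a minimal sketch of how the new truncate helper and the ToDecimal instance behave, assuming the definitions above are in scope and that nanHandler maps NaN to None as its name suggests; the sample values are illustrative rather than taken from the suite:

// Infinities collapse to a fixed sentinel value so both serializations agree.
DoubleBehaviourUtils.truncate(Double.PositiveInfinity) // BigDecimal(1000000.000000)
DoubleBehaviourUtils.truncate(Double.NegativeInfinity) // BigDecimal(-1000000.000000)

// Finite doubles are cut towards zero at six decimal places
// (FLOOR for positive values, CEILING for negative ones).
DoubleBehaviourUtils.truncate(1.23456789)  // BigDecimal(1.234567)
DoubleBehaviourUtils.truncate(-1.23456789) // BigDecimal(-1.234567)

// ToDecimal routes through nanNullHandler, so NaN (and null) become None.
implicitly[ToDecimal[Double]].truncate(Double.NaN) // None, assuming nanHandler drops NaN
implicitly[ToDecimal[Double]].truncate(2.5)        // Some(2.500000)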
@@ -798,7 +798,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

def prop[A: CatalystNumeric: TypedEncoder: Encoder](
def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered](
na: A,
values: List[X1[A]]
)(implicit
@@ -811,6 +811,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect()
.toList
.sorted

val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -820,20 +821,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.run()
.toList
.sorted

val aggrTyped = typedDS
.orderBy(typedDS('a).asc)
.agg(atan(frameless.functions.aggregate.first(typedDS('a))))
.firstOption()
.run()
.get

val aggrSpark = cDS
.orderBy("a")
.select(
sparkFunctions.atan(sparkFunctions.first("a")).as[Double]
)
.first()

(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
(res ?= resCompare).&&(
DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils
.nanNullHandler(aggrSpark)
)
}

check(forAll(prop[Int] _))
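
The recurring shape of the test changes in this file is: add a CatalystOrdered bound so the column can be sorted, sort both collected result lists, order the dataset before taking first so the aggregate is deterministic, and compare the aggregated doubles through nanNullHandler. A condensed sketch of that pattern follows; spark, xs and the surrounding implicits are assumed stand-ins mirroring the suite's fixtures, not part of this commit:

// Sketch only: assumes a SparkSession `spark`, sample data `xs: Seq[X1[Double]]`,
// spark.implicits._ in scope, and the frameless imports/implicits used by the suite.
val typedDS = TypedDataset.create(xs)

// Ordering before `first` makes the aggregate pick a well-defined row,
// so the typed and untyped pipelines agree on which value they feed to atan.
val aggrTyped = typedDS
  .orderBy(typedDS('a).asc)
  .agg(atan(frameless.functions.aggregate.first(typedDS('a))))
  .firstOption()
  .run()
  .get

val aggrSpark = typedDS.dataset
  .orderBy("a")
  .select(sparkFunctions.atan(sparkFunctions.first("a")).as[Double])
  .first()

// nanNullHandler absorbs NaN/Infinity and rounding differences between the
// two execution paths before the equality check.
assert(
  DoubleBehaviourUtils.nanNullHandler(aggrTyped) ==
    DoubleBehaviourUtils.nanNullHandler(aggrSpark)
)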
@@ -849,8 +856,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._

def prop[
A: CatalystNumeric: TypedEncoder: Encoder,
B: CatalystNumeric: TypedEncoder: Encoder
A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered,
B: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered
](na: X2[A, B],
values: List[X2[A, B]]
)(implicit
@@ -863,6 +870,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect()
.toList
.sorted

val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -872,8 +880,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.run()
.toList
.sorted

val aggrTyped = typedDS
.orderBy(typedDS('a).asc, typedDS('b).asc)
.agg(
atan2(
frameless.functions.aggregate.first(typedDS('a)),
@@ -885,14 +895,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.get

val aggrSpark = cDS
.orderBy("a", "b")
.select(
sparkFunctions
.atan2(sparkFunctions.first("a"), sparkFunctions.first("b"))
.as[Double]
)
.first()

(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
(res ?= resCompare).&&(
DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils
.nanNullHandler(aggrSpark)
)
}

check(forAll(prop[Int, Long] _))
@@ -907,7 +921,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

def prop[A: CatalystNumeric: TypedEncoder: Encoder](
def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered](
na: X1[A],
value: List[X1[A]],
lit: Double
@@ -921,6 +935,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect()
.toList
.sorted

val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -930,20 +945,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.run()
.toList
.sorted

val aggrTyped = typedDS
.orderBy(typedDS('a).asc)
.agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a))))
.firstOption()
.run()
.get

val aggrSpark = cDS
.orderBy("a")
.select(
sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double]
)
.first()

(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
(res ?= resCompare).&&(
DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils
.nanNullHandler(aggrSpark)
)
}

check(forAll(prop[Int] _))
@@ -958,7 +979,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

def prop[A: CatalystNumeric: TypedEncoder: Encoder](
def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered](
na: X1[A],
value: List[X1[A]],
lit: Double
@@ -972,6 +993,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect()
.toList
.sorted

val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -981,20 +1003,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.run()
.toList
.sorted

val aggrTyped = typedDS
.orderBy(typedDS('a).asc)
.agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit))
.firstOption()
.run()
.get

val aggrSpark = cDS
.orderBy("a")
.select(
sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double]
)
.first()

(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
(res ?= resCompare).&&(
DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils
.nanNullHandler(aggrSpark)
)
}

check(forAll(prop[Int] _))
@@ -2139,15 +2167,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(_.getAs[Int](0))
.collect()
.toVector
.sorted

val typed = ds
.select(levenshtein(ds('a), concat(ds('a), lit("Hello"))))
.collect()
.run()
.toVector
.sorted

val cDS = ds.dataset
val aggrTyped = ds
.orderBy(ds('a).asc)
.agg(
levenshtein(
frameless.functions.aggregate.first(ds('a)),
Expand All @@ -2159,6 +2190,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.get

val aggrSpark = cDS
.orderBy("a")
.select(
sparkFunctions
.levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello"))