Commit

- Add support for randn/rand_range/rand_laplace with column
- Fix Laplace distribution having the wrong standard deviation
zeotuan committed Oct 12, 2024
2 parents eef06ec + ffa6005 commit bb8dade
Showing 7 changed files with 286 additions and 37 deletions.
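
For reference: a Laplace(μ, b) distribution has variance 2b², so its standard deviation is √2·b. The updated rand_laplace below samples it by inverting the CDF — with u = rand(seed) − 0.5 it returns μ − b·sgn(u)·ln(1 − 2|u|), which is distributed Laplace(μ, b).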
2 changes: 1 addition & 1 deletion .github/workflows/core-ci.yml
@@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
2 changes: 1 addition & 1 deletion .github/workflows/unsafe-ci.yml
@@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.2.4", "3.3.4"]
spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
26 changes: 19 additions & 7 deletions build.sbt
@@ -1,23 +1,23 @@
import scala.language.postfixOps

Compile / scalafmtOnCompile := true

organization := "com.github.mrpowers"
name := "spark-daria"

version := "1.2.3"

crossScalaVersions := Seq("2.12.15", "2.13.8")
scalaVersion := "2.12.15"

val versionRegex = """^(.*)\.(.*)\.(.*)$""".r

val scala2_13 = "2.13.14"
val scala2_13 = "2.13.15"
val scala2_12 = "2.12.20"

val sparkVersion = System.getProperty("spark.testVersion", "3.3.4")
crossScalaVersions := {
sparkVersion match {
case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13)
case versionRegex("3", _, _) => Seq(scala2_12)
case versionRegex("4", _, _) => Seq(scala2_13)
}
}

@@ -32,9 +32,9 @@ lazy val commonSettings = Seq(
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-mllib" % sparkVersion % "provided",
"com.github.mrpowers" %% "spark-fast-tests" % "1.1.0" % "test",
"com.lihaoyi" %% "utest" % "0.7.11" % "test",
"com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
"com.github.mrpowers" %% "spark-fast-tests" % "1.3.0" % "test",
"com.lihaoyi" %% "utest" % "0.8.2" % "test",
"com.lihaoyi" %% "os-lib" % "0.10.3" % "test"
),
)

@@ -48,6 +48,18 @@ lazy val unsafe = (project in file("unsafe"))
.settings(
commonSettings,
name := "unsafe",
Compile / unmanagedSourceDirectories ++= {
sparkVersion match {
case versionRegex(mayor, minor, _) =>
(Compile / sourceDirectory).value ** s"*spark_*$mayor.$minor*" / "scala" get
}
},
Test / unmanagedSourceDirectories ++= {
sparkVersion match {
case versionRegex(mayor, minor, _) =>
(Compile / sourceDirectory).value ** s"*spark_*$mayor.$minor*" / "scala" get
}
},
)

testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")
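
The Spark version under test comes from the spark.testVersion system property (default 3.3.4), so a cross-build against a newer Spark would presumably be run as something like sbt -Dspark.testVersion=3.5.3 +test, in which case the version regex above picks both Scala 2.12.20 and 2.13.15. The exact command is an assumption — the commit itself only touches the build definition and the CI matrices.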
94 changes: 81 additions & 13 deletions unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala
@@ -2,7 +2,7 @@ package org.apache.spark.sql.daria

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, RandGamma}
import org.apache.spark.sql.functions.{lit, log, when}
import org.apache.spark.sql.functions.{lit, log, signum, when}
import org.apache.spark.sql.{functions => F}
import org.apache.spark.util.Utils

@@ -15,7 +15,15 @@ object functions {
*
* @note The function is non-deterministic in general case.
*/
def rand_gamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Gamma distribution with the specified shape and scale parameters.
*
* @note The function is non-deterministic in general case.
*/
def rand_gamma(seed: Column, shape: Column, scale: Column): Column = withExpr(RandGamma(seed.expr, shape.expr, scale.expr)).alias("gamma_random")

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
@@ -25,6 +33,14 @@
*/
def rand_gamma(shape: Double, scale: Double): Column = rand_gamma(Utils.random.nextLong, shape, scale)

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Gamma distribution with the specified shape and scale parameters.
*
* @note The function is non-deterministic in general case.
*/
def rand_gamma(shape: Column, scale: Column): Column = rand_gamma(lit(Utils.random.nextLong), shape, scale)

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Gamma distribution with default parameters (shape = 1.0, scale = 1.0).
@@ -35,21 +51,35 @@
*/
def rand_gamma(): Column = rand_gamma(1.0, 1.0)

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`.
*
* @note The function is non-deterministic in general case.
*/
def rand_laplace(seed: Long, mu: Column, beta: Column): Column = {
val u = F.rand(seed) - lit(0.5)
mu - beta * signum(u) * log(lit(1) - (lit(2) * F.abs(u)))
}

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`.
*
* @note The function is non-deterministic in general case.
*/
def rand_laplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
val beta_ = lit(beta)
val u = F.rand(seed)
when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
.alias("laplace_random")
rand_laplace(seed, lit(mu), lit(beta))
}

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`.
*
* @note The function is non-deterministic in general case.
*/
def rand_laplace(mu: Column, beta: Column): Column = rand_laplace(Utils.random.nextLong, mu, beta)

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`.
@@ -66,16 +96,24 @@
*/
def rand_laplace(): Column = rand_laplace(0.0, 1.0)

/**
* Generate a random column with independent and identically distributed (i.i.d.) samples
* uniformly distributed in [`min`, `max`).
*
* @note The function is non-deterministic in general case.
*/
def rand_range(seed: Long, min: Column, max: Column): Column = {
min + (max - min) * F.rand(seed)
}

/**
* Generate a random column with independent and identically distributed (i.i.d.) samples
* uniformly distributed in [`min`, `max`).
*
* @note The function is non-deterministic in general case.
*/
def rand_range(seed: Long, min: Int, max: Int): Column = {
val min_ = lit(min)
val max_ = lit(max)
min_ + (max_ - min_) * F.rand(seed)
rand_range(seed, lit(min), lit(max))
}

/**
@@ -88,15 +126,35 @@
rand_range(Utils.random.nextLong, min, max)
}

/**
* Generate a random column with independent and identically distributed (i.i.d.) samples
* uniformly distributed in [`min`, `max`).
*
* @note The function is non-deterministic in general case.
*/
def rand_range(min: Column, max: Column): Column = {
rand_range(Utils.random.nextLong, min, max)
}

/**
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution with given `mean` and `variance`.
*
* @note The function is non-deterministic in general case.
*/
def randn(seed: Long, mean: Column, variance: Column): Column = {
val stddev = F.sqrt(variance)
F.randn(seed) * stddev + lit(mean)
}

/**
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution with given `mean` and `variance`.
*
* @note The function is non-deterministic in general case.
*/
def randn(seed: Long, mean: Double, variance: Double): Column = {
val stddev = math.sqrt(variance)
F.randn(seed) * lit(stddev) + lit(mean)
randn(seed, lit(mean), lit(variance))
}

/**
@@ -108,4 +166,14 @@
def randn(mean: Double, variance: Double): Column = {
randn(Utils.random.nextLong, mean, variance)
}

/**
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution with given `mean` and `variance`.
*
* @note The function is non-deterministic in general case.
*/
def randn(mean: Column, variance: Column): Column = {
randn(Utils.random.nextLong, mean, variance)
}
}
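
The new Column-based overloads let the distribution parameters be arbitrary Column expressions, so each row can draw from its own parameterization. A minimal usage sketch — the DataFrame df and its column names (mu, beta, lo, hi, m, v) are illustrative assumptions, not part of the commit:

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.daria.functions._

// Each sample is parameterized by that row's column values.
val sampled = df.select(
  rand_laplace(col("mu"), col("beta")).alias("laplace_sample"), // location mu, scale beta
  rand_range(col("lo"), col("hi")).alias("uniform_sample"),     // uniform in [lo, hi)
  randn(col("m"), col("v")).alias("normal_sample")              // mean m, variance v
)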
@@ -2,16 +2,13 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
@@ -43,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

@@ -87,10 +84,4 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression =
Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}
@@ -0,0 +1,89 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandomAdapted

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
with Nondeterministic
with ExpressionWithRandomSeed {

def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

@transient protected lazy val shapeVal: Double = shape.dataType match {
case IntegerType => shape.eval().asInstanceOf[Int]
case LongType => shape.eval().asInstanceOf[Long]
case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
}

@transient protected lazy val scaleVal: Double = scale.dataType match {
case IntegerType => scale.eval().asInstanceOf[Int]
case LongType => scale.eval().asInstanceOf[Long]
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

@transient private var distribution: GammaDistribution = _

override protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

protected def evalInternal(input: InternalRow): Double = distribution.sample()

def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
)
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

override def stateful: Boolean = true

def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

def dataType: DataType = DoubleType

def first: Expression = child

def second: Expression = shape

def third: Expression = scale

protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
}
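
A quick sanity-check sketch for the Gamma sampler (not from the commit; it assumes a local SparkSession): Gamma(shape = k, scale = θ) has mean kθ and variance kθ², so the sample moments of a large draw should land near those values.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{mean, variance}
import org.apache.spark.sql.daria.functions.rand_gamma

val spark = SparkSession.builder().master("local[*]").appName("rand_gamma-check").getOrCreate()

// 100,000 draws from Gamma(shape = 2.0, scale = 3.0): expect mean ≈ 6.0 and variance ≈ 18.0.
val samples = spark.range(100000).select(rand_gamma(2.0, 3.0).alias("g"))
samples.agg(mean("g"), variance("g")).show()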