Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support rand_gamma on spark 3.4+ #165

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/core-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.0.3", "3.1.3", "3.2.4", "3.3.4", "3.4.3", "3.5.3"]
spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unsafe-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.2.4", "3.3.4"]
spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import org.apache.spark.util.Utils
object functions {
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)
def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(seed: Column, shape: Column, scale: Column): Column = withExpr(RandGamma(seed.expr, shape.expr, scale.expr)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(shape: Column, scale: Column): Column = randGamma(lit(Utils.random.nextLong), shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
Expand Down Expand Up @@ -43,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

Expand Down Expand Up @@ -87,10 +84,4 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression =
Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandomAdapted

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
with Nondeterministic
with ExpressionWithRandomSeed {

def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

@transient protected lazy val shapeVal: Double = shape.dataType match {
case IntegerType => shape.eval().asInstanceOf[Int]
case LongType => shape.eval().asInstanceOf[Long]
case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
}

@transient protected lazy val scaleVal: Double = scale.dataType match {
case IntegerType => scale.eval().asInstanceOf[Int]
case LongType => scale.eval().asInstanceOf[Long]
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

@transient private var distribution: GammaDistribution = _

override protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

protected def evalInternal(input: InternalRow): Double = distribution.sample()

def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
)
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

override def stateful: Boolean = true

def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

def dataType: DataType = DoubleType

def first: Expression = child

def second: Expression = shape

def third: Expression = scale

protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
assert(math.abs(gammaMean - 4.0) < 0.5)
assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
}

"has correct mean and standard deviation from shape/scale column" - {
val sourceDF = spark
.range(100000)
.withColumn("shape", lit(2.0))
.withColumn("scale", lit(2.0))
.select(randGamma(col("shape"), col("shape")))
val stats = sourceDF
.agg(
mean("gamma_random").as("mean"),
stddev("gamma_random").as("stddev")
)
.collect()(0)

val gammaMean = stats.getAs[Double]("mean")
val gammaStddev = stats.getAs[Double]("stddev")

// Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0)
assert(gammaMean > 0)
assert(math.abs(gammaMean - 4.0) < 0.5)
assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
}
}

'rand_laplace - {
Expand Down
Loading