From 03c951f9488c1a7eb3a18da7934ebd232e4beddf Mon Sep 17 00:00:00 2001
From: Tuan Pham
Date: Fri, 11 Oct 2024 19:51:44 +1100
Subject: [PATCH 1/2] support rand_gamma on spark 3.4+

---
 .github/workflows/core-ci.yml                |  2 +-
 .github/workflows/unsafe-ci.yml              |  2 +-
 .../sql/catalyst/expressions/RandGamma.scala | 13 +--
 .../sql/catalyst/expressions/RandGamma.scala | 89 +++++++++++++++++++
 4 files changed, 93 insertions(+), 13 deletions(-)
 rename unsafe/src/main/{scala/org/apache/spark => spark_3.2_3.3/scala/org.apache.spark}/sql/catalyst/expressions/RandGamma.scala (86%)
 create mode 100644 unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala

diff --git a/.github/workflows/core-ci.yml b/.github/workflows/core-ci.yml
index a080076..1ead725 100644
--- a/.github/workflows/core-ci.yml
+++ b/.github/workflows/core-ci.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        spark: ["3.0.3", "3.1.3", "3.2.4", "3.3.4", "3.4.3", "3.5.3"]
+        spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
diff --git a/.github/workflows/unsafe-ci.yml b/.github/workflows/unsafe-ci.yml
index 9b25c31..fee5eaa 100644
--- a/.github/workflows/unsafe-ci.yml
+++ b/.github/workflows/unsafe-ci.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        spark: ["3.2.4", "3.3.4"]
+        spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
diff --git a/unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala b/unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
similarity index 86%
rename from unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala
rename to unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
index ce3aed5..4141478 100644
--- a/unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala
+++ b/unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
@@ -2,16 +2,13 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.commons.math3.distribution.GammaDistribution
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
 import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandomAdapted
 
-import scala.util.{Success, Try}
-
 case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
   extends TernaryExpression
     with ExpectsInputTypes
@@ -43,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
     distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
   }
 
-  def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
+  def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
 
   def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)
 
@@ -87,10 +84,4 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
 object RandGamma {
   def apply(seed: Long, shape: Double, scale: Double): RandGamma =
     RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
-
-  def defaultSeedExpression: Expression =
-    Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
-      case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
-      case _ => Literal(Utils.random.nextLong(), LongType)
-    }
 }
diff --git a/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala b/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
new file mode 100644
index 0000000..ab2df0e
--- /dev/null
+++ b/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
@@ -0,0 +1,89 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.commons.math3.distribution.GammaDistribution
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.random.XORShiftRandomAdapted
+
+case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
+  extends TernaryExpression
+    with ExpectsInputTypes
+    with Nondeterministic
+    with ExpressionWithRandomSeed {
+
+  def seedExpression: Expression = child
+
+  @transient protected lazy val seed: Long = seedExpression match {
+    case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
+    case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
+  }
+
+  @transient protected lazy val shapeVal: Double = shape.dataType match {
+    case IntegerType => shape.eval().asInstanceOf[Int]
+    case LongType => shape.eval().asInstanceOf[Long]
+    case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
+  }
+
+  @transient protected lazy val scaleVal: Double = scale.dataType match {
+    case IntegerType => scale.eval().asInstanceOf[Int]
+    case LongType => scale.eval().asInstanceOf[Long]
+    case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
+  }
+
+  @transient private var distribution: GammaDistribution = _
+
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
+  }
+
+  def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
+
+  def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)
+
+  def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)
+
+  protected def evalInternal(input: InternalRow): Double = distribution.sample()
+
+  def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val distributionClassName = classOf[GammaDistribution].getName
+    val rngClassName = classOf[XORShiftRandomAdapted].getName
+    val disTerm = ctx.addMutableState(distributionClassName, "distribution")
+    ctx.addPartitionInitializationStatement(
+      s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
+    )
+    ev.copy(code = code"""
+      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
+  }
+
+  def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)
+
+  override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)
+
+  override def prettyName: String = "rand_gamma"
+
+  override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"
+
+  override def stateful: Boolean = true
+
+  def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)
+
+  def dataType: DataType = DoubleType
+
+  def first: Expression = child
+
+  def second: Expression = shape
+
+  def third: Expression = scale
+
+  protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(child = newFirst, shape = newSecond, scale = newThird)
+}
+
+object RandGamma {
+  def apply(seed: Long, shape: Double, scale: Double): RandGamma =
+    RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
+}

From 178acc13cba91c42f0c98f608e4ff809fdcdb2ea Mon Sep 17 00:00:00 2001
From: Tuan Pham
Date: Fri, 11 Oct 2024 23:28:55 +1100
Subject: [PATCH 2/2] public randGamma API with shape/scale expression

---
 .../apache/spark/sql/daria/functions.scala   |  8 ++++---
 .../spark/sql/daria/functionsTests.scala     | 22 +++++++++++++++++++
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala
index d4322e5..2445065 100644
--- a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala
+++ b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala
@@ -8,9 +8,11 @@ import org.apache.spark.util.Utils
 object functions {
   private def withExpr(expr: Expression): Column = Column(expr)
 
-  def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
-  def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
-  def randGamma(): Column = randGamma(1.0, 1.0)
+  def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
+  def randGamma(seed: Column, shape: Column, scale: Column): Column = withExpr(RandGamma(seed.expr, shape.expr, scale.expr)).alias("gamma_random")
+  def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
+  def randGamma(shape: Column, scale: Column): Column = randGamma(lit(Utils.random.nextLong), shape, scale)
+  def randGamma(): Column = randGamma(1.0, 1.0)
 
   def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
     val mu_ = lit(mu)
diff --git a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
index 5147c13..b529c1e 100644
--- a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
+++ b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
@@ -27,6 +27,28 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
         assert(math.abs(gammaMean - 4.0) < 0.5)
         assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
       }
+
+      "has correct mean and standard deviation from shape/scale columns" - {
+        val sourceDF = spark
+          .range(100000)
+          .withColumn("shape", lit(2.0))
+          .withColumn("scale", lit(2.0))
+          .select(randGamma(col("shape"), col("scale")))
+        val stats = sourceDF
+          .agg(
mean("gamma_random").as("mean"), + stddev("gamma_random").as("stddev") + ) + .collect()(0) + + val gammaMean = stats.getAs[Double]("mean") + val gammaStddev = stats.getAs[Double]("stddev") + + // Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0) + assert(gammaMean > 0) + assert(math.abs(gammaMean - 4.0) < 0.5) + assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5) + } } 'rand_laplace - {