Skip to content

Commit

Permalink
use reflection to get UnresolvedSeed
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Sep 23, 2024
1 parent a7d7559 commit 50cbdfe
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.Try
import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
Expand Down Expand Up @@ -40,7 +42,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
}
@transient private var distribution: GammaDistribution = _

def this() = this(Try(org.apache.spark.sql.catalyst.analysis.UnresolvedSeed).getOrElse(Literal(42L, LongType)), Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

Expand Down Expand Up @@ -84,4 +86,9 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression = Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}

0 comments on commit 50cbdfe

Please sign in to comment.