diff --git a/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala b/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala index 8085068..bcbcd0c 100644 --- a/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala +++ b/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala @@ -2,13 +2,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.math3.distribution.GammaDistribution import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandomAdapted -import scala.util.Try +import scala.util.{Success, Try} case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false) extends TernaryExpression @@ -40,7 +42,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi } @transient private var distribution: GammaDistribution = _ - def this() = this(Try(org.apache.spark.sql.catalyst.analysis.UnresolvedSeed).getOrElse(Literal(42L, LongType)), Literal(1.0, DoubleType), Literal(1.0, DoubleType), true) + def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true) def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false) @@ -84,4 +86,9 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi object RandGamma { def apply(seed: Long, shape: Double, scale: Double): RandGamma = RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType)) + + def defaultSeedExpression: Expression = Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match { + case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression] + case _ => Literal(Utils.random.nextLong(), LongType) + } }