From 5e0c3bf45f85473f4528cea65856f37671c751ab Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Sat, 2 Nov 2024 22:43:58 -0500 Subject: [PATCH 1/4] GlobalParamObject implementation with higher-order getters. DeploymentName added to GlobalParam set --- .../ml/services/openai/GlobalParam.scala | 161 ++++++++++++++++++ .../synapse/ml/services/openai/OpenAI.scala | 6 +- .../openai/OpenAIChatCompletion.scala | 2 +- .../ml/services/openai/OpenAICompletion.scala | 2 +- .../ml/services/openai/OpenAIEmbedding.scala | 2 +- .../ml/services/openai/GlobalParamSuite.scala | 64 +++++++ 6 files changed, 232 insertions(+), 5 deletions(-) create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala new file mode 100644 index 0000000000..bdac25dd61 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala @@ -0,0 +1,161 @@ +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services.HasServiceParams +import org.apache.spark.ml.param.Param +import org.apache.spark.sql.Row + +import scala.collection.mutable + +sealed trait GlobalParamKey[T] +case object OpenAIDeploymentNameKey extends GlobalParamKey[String] + +object GlobalParamObject { + private val ParamToKeyMap: mutable.Map[Any, GlobalParamKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty + + def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { + GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + } + + def getGlobalParamKey[T](key: GlobalParamKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParamKey(k) + case _ => None + } + } + } + + def getServiceParam[T](sp: ServiceParam[T]): Option[T] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParamKey(k) + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(p) = key + } + + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(sp) = key + } +} + + + +trait HasGlobalParams extends HasServiceParams { + + def getGlobalParam[T](p: Param[T]): T = { + try { + this.getOrDefault(p) + } + catch { + case e: Exception => + GlobalParamObject.getParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalParam[T](name: String): T = { + val param = this.getParam(name).asInstanceOf[Param[T]] + try { + this.getOrDefault(param) + } + catch { + case e: Exception => + GlobalParamObject.getParam(param) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](p: ServiceParam[T]): T= { //TODO: should return a string + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p)match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](name: String): T = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](name: String): T = { //TODO: should return a string + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { + //TODO: Confirm with Mark what this colName business is about + try { + get(p).orElse(getDefault(p)).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + match { + case some @ Some(_) => some + case None => GlobalParamObject.getServiceParam(p) + } + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p) match { + case Some(v) => Some(v) + case None => throw e + } + } + } + + + protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = + getGlobalServiceParamValueOpt(row, p).get +} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index b57f4d65da..883b2b7c61 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,12 +51,14 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } -trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { +trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) - def getDeploymentName: String = getScalarParam(deploymentName) + GlobalParamObject.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) + + def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName) def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index 703fc7f471..e667cc0625 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -35,7 +35,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/chat/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala index 953138bc36..5fbbb1243b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala @@ -34,7 +34,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index 58f5c857d6..8c9ee199f3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -59,7 +59,7 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/embeddings" } private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = { diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala new file mode 100644 index 0000000000..c3612ddf46 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala @@ -0,0 +1,64 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.Secrets.getAccessToken +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import org.apache.spark.ml.util.MLReadable +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Row} +import org.scalactic.Equality + +class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { + + import spark.implicits._ + + override def beforeAll(): Unit = { + val aadToken = getAccessToken("https://cognitiveservices.azure.com/") + println(s"Triggering token creation early ${aadToken.length}") + super.beforeAll() + } + + GlobalParamObject.setGlobalParamKey(OpenAIDeploymentNameKey, deploymentName) + + lazy val prompt: OpenAIPrompt = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + + lazy val df: DataFrame = Seq( + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + (null, "none") //scalastyle:ignore null + ).toDF("text", "category") + + test("Basic Usage") { + val nonNullCount = prompt + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessing("csv") + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + + assert(nonNullCount == 3) + } + + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { + super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) + } + + override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val testPrompt = prompt + .setPromptTemplate("{text} rhymes with ") + + Seq(new TestObject(testPrompt, df)) + } + + override def reader: MLReadable[_] = OpenAIPrompt + +} From 7451422098237906e7f8a47d8f932e5ff20551ea Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 7 Nov 2024 11:44:13 -0500 Subject: [PATCH 2/4] Enable ServiceParams as part of GlobalParam --- .../synapse/ml/services/GlobalParam.scala | 213 ++++++++++++++++++ .../ml/services/openai/GlobalParam.scala | 161 ------------- .../synapse/ml/services/openai/OpenAI.scala | 2 +- .../{openai => }/GlobalParamSuite.scala | 24 +- 4 files changed, 219 insertions(+), 181 deletions(-) create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala delete mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala rename cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/{openai => }/GlobalParamSuite.scala (64%) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala new file mode 100644 index 0000000000..b1546985d0 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -0,0 +1,213 @@ +package com.microsoft.azure.synapse.ml.services + +import com.microsoft.azure.synapse.ml.param.ServiceParam +import org.apache.spark.ml.param.Param +import org.apache.spark.sql.Row + +import scala.collection.mutable + +sealed trait GlobalKey[T] +sealed trait GlobalParamKey[T] extends GlobalKey[T] +sealed trait GlobalServiceParamKey[T] extends GlobalKey[Either[T, String]] //left is Scalar, right is Vector + +case object OpenAIDeploymentNameKey extends GlobalServiceParamKey[String] + +object GlobalParams { + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty + private val GlobalServiceParams: mutable.Map[GlobalServiceParamKey[Any], Either[Any, String]] = mutable.Map.empty + + private def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { + GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + } + + private def setGlobalServiceParamKey[T](key: GlobalServiceParamKey[T], value: T): Unit = { + GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Left(value) + } + + private def setGlobalServiceParamKeyCol[T](key: GlobalServiceParamKey[T], value: String): Unit = { + GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Right(value) + } + + private def getGlobalParam[T](key: GlobalParamKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + } + + private def getGlobalServiceParam[T](key: GlobalServiceParamKey[T]): Option[Either[T, String]] = { + GlobalServiceParams.get(key.asInstanceOf[GlobalServiceParamKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParam(k) + case _ => None + } + } + } + + def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => getGlobalServiceParam(k) + case _ => None + } + } + } + + def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => + getGlobalServiceParam(k) match { + case Some(Left(value)) => Some(value) + case _ => None + } + case _ => None + } + } + } + + def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => + getGlobalServiceParam(k) match { + case Some(Right(colName)) => Some(colName) + case _ => None + } + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(p) = key + } + + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalServiceParamKey[T]): Unit = { + ParamToKeyMap(sp) = key + } + + def setDeploymentName(deploymentName: String): Unit = { + setGlobalServiceParamKey(OpenAIDeploymentNameKey, deploymentName) + } + + def setDeploymentNameCol(colName: String): Unit = { + setGlobalServiceParamKeyCol(OpenAIDeploymentNameKey, colName) + } +} + + + +trait HasGlobalParams extends HasServiceParams { + + def getGlobalParam[T](p: Param[T]): T = { + try { + this.getOrDefault(p) + } + catch { + case e: Exception => + GlobalParams.getParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalParam[T](name: String): T = { + val param = this.getParam(name).asInstanceOf[Param[T]] + try { + this.getOrDefault(param) + } + catch { + case e: Exception => + GlobalParams.getParam(param) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParams.getServiceParamScalar(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](p: ServiceParam[T]): String= { + try { + this.getVectorParam(p) + } + catch { + case e: Exception => + GlobalParams.getServiceParamVector(p)match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](name: String): T = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParams.getServiceParamScalar(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](name: String): String = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getVectorParam(serviceParam) + } + catch { + case e: Exception => + GlobalParams.getServiceParamVector(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { + val globalParam: Option[T] = GlobalParams.getServiceParam(p).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + try { + get(p).orElse(getDefault(p)).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + match { + case some @ Some(_) => some + case None => globalParam + } + } + catch { + case e: Exception => + globalParam match { + case Some(v) => Some(v) + case None => throw e + } + } + } + + + protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = + getGlobalServiceParamValueOpt(row, p).get +} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala deleted file mode 100644 index bdac25dd61..0000000000 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala +++ /dev/null @@ -1,161 +0,0 @@ -package com.microsoft.azure.synapse.ml.services.openai - -import com.microsoft.azure.synapse.ml.param.ServiceParam -import com.microsoft.azure.synapse.ml.services.HasServiceParams -import org.apache.spark.ml.param.Param -import org.apache.spark.sql.Row - -import scala.collection.mutable - -sealed trait GlobalParamKey[T] -case object OpenAIDeploymentNameKey extends GlobalParamKey[String] - -object GlobalParamObject { - private val ParamToKeyMap: mutable.Map[Any, GlobalParamKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty - - def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { - GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value - } - - def getGlobalParamKey[T](key: GlobalParamKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) - } - - def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalParamKey[T] => getGlobalParamKey(k) - case _ => None - } - } - } - - def getServiceParam[T](sp: ServiceParam[T]): Option[T] = { - ParamToKeyMap.get(sp).flatMap { key => - key match { - case k: GlobalParamKey[T] => getGlobalParamKey(k) - case _ => None - } - } - } - - def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { - ParamToKeyMap(p) = key - } - - def registerServiceParam[T](sp: ServiceParam[T], key: GlobalParamKey[T]): Unit = { - ParamToKeyMap(sp) = key - } -} - - - -trait HasGlobalParams extends HasServiceParams { - - def getGlobalParam[T](p: Param[T]): T = { - try { - this.getOrDefault(p) - } - catch { - case e: Exception => - GlobalParamObject.getParam(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalParam[T](name: String): T = { - val param = this.getParam(name).asInstanceOf[Param[T]] - try { - this.getOrDefault(param) - } - catch { - case e: Exception => - GlobalParamObject.getParam(param) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { - try { - this.getScalarParam(p) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](p: ServiceParam[T]): T= { //TODO: should return a string - try { - this.getScalarParam(p) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p)match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](name: String): T = { - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getScalarParam(serviceParam) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](name: String): T = { //TODO: should return a string - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getScalarParam(serviceParam) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { - //TODO: Confirm with Mark what this colName business is about - try { - get(p).orElse(getDefault(p)).flatMap { - case Right(colName) => Option(row.getAs[T](colName)) - case Left(value) => Some(value) - } - match { - case some @ Some(_) => some - case None => GlobalParamObject.getServiceParam(p) - } - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p) match { - case Some(v) => Some(v) - case None => throw e - } - } - } - - - protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = - getGlobalServiceParamValueOpt(row, p).get -} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index f91b9b418b..dd1b61b63e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -56,7 +56,7 @@ trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) - GlobalParamObject.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) + GlobalParams.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala similarity index 64% rename from cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala rename to cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index c3612ddf46..c6714decaa 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -1,17 +1,17 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.azure.synapse.ml.services.openai +package com.microsoft.azure.synapse.ml.services import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.core.test.base.Flaky import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt} import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame import org.scalactic.Equality -class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { +class GlobalParamSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ @@ -21,7 +21,7 @@ class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKe super.beforeAll() } - GlobalParamObject.setGlobalParamKey(OpenAIDeploymentNameKey, deploymentName) + GlobalParams.setDeploymentName(deploymentName) lazy val prompt: OpenAIPrompt = new OpenAIPrompt() .setSubscriptionKey(openAIAPIKey) @@ -47,18 +47,4 @@ class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKe assert(nonNullCount == 3) } - - override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { - super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) - } - - override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { - val testPrompt = prompt - .setPromptTemplate("{text} rhymes with ") - - Seq(new TestObject(testPrompt, df)) - } - - override def reader: MLReadable[_] = OpenAIPrompt - } From efeb9d0ea3595ebfb3d9a7b395779bfd29a28238 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Mon, 11 Nov 2024 11:07:31 -0500 Subject: [PATCH 3/4] Only one GlobalKey type and only one GlobalParams dictionary. Added asserts to type check Params vs Service Params. Added more tests. Tests pass! --- .../synapse/ml/services/GlobalParam.scala | 159 ++++++++++++++---- .../synapse/ml/services/openai/OpenAI.scala | 1 + .../ml/services/GlobalParamSuite.scala | 72 +++++++- 3 files changed, 193 insertions(+), 39 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala index b1546985d0..53d21a8ebf 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -5,42 +5,135 @@ import org.apache.spark.ml.param.Param import org.apache.spark.sql.Row import scala.collection.mutable +import scala.reflect.ClassTag -sealed trait GlobalKey[T] -sealed trait GlobalParamKey[T] extends GlobalKey[T] -sealed trait GlobalServiceParamKey[T] extends GlobalKey[Either[T, String]] //left is Scalar, right is Vector +trait GlobalKey[T] { + val name: String + val isServiceParam: Boolean +} -case object OpenAIDeploymentNameKey extends GlobalServiceParamKey[String] +case object OpenAIDeploymentNameKey extends GlobalKey[String] + {val name: String = "OpenAIDeploymentName"; val isServiceParam = true} +case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} +case object TestServiceParamKey extends GlobalKey[Int] + {val name: String = "TestServiceParam"; val isServiceParam = true} object GlobalParams { private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty - private val GlobalServiceParams: mutable.Map[GlobalServiceParamKey[Any], Either[Any, String]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + + private def boxedClass(c: Class[_]): Class[_] = { + if (!c.isPrimitive) c + c match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Float.TYPE => classOf[java.lang.Float] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => c // Fallback for any other primitive types + } + } + + private val StringtoKeyMap: Map[String, GlobalKey[_]] = Map( + "OpenAIDeploymentName" -> OpenAIDeploymentNameKey, + "TestParam" -> TestParamKey, + "TestServiceParam" -> TestServiceParamKey, + ) + + private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { + StringtoKeyMap.get(keyName) + } + + def setGlobalParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.") + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + GlobalParams(key) = value + } - private def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { - GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + def setGlobalParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(!k.isServiceParam, s"${k.name} is a Service Param. setGlobalServiceParamKey should be used.") + GlobalParams(k) = value + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } } - private def setGlobalServiceParamKey[T](key: GlobalServiceParamKey[T], value: T): Unit = { - GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Left(value) + def setGlobalServiceParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + GlobalParams(key) = Left(value) } - private def setGlobalServiceParamKeyCol[T](key: GlobalServiceParamKey[T], value: String): Unit = { - GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Right(value) + def setGlobalServiceParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(k) = Left(value) + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } } - private def getGlobalParam[T](key: GlobalParamKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + def setGlobalServiceParamCol[T](key: GlobalKey[T], value: String): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(key) = Right(value) } - private def getGlobalServiceParam[T](key: GlobalServiceParamKey[T]): Option[Either[T, String]] = { - GlobalServiceParams.get(key.asInstanceOf[GlobalServiceParamKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + def setGlobalServiceParamCol[T](keyName: String, value: String): Unit = { + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(k) = Right(value) + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } + } + + private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + } + + private def getGlobalServiceParam[T](key: GlobalKey[T]): Option[Either[T, String]] = { + val value = GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + value match { + case some @ Some(v) => + assert(v.isInstanceOf[Either[T, String]], + "getGlobalServiceParam used to fetch a normal Param") + value + case None => None + } } def getParam[T](p: Param[T]): Option[T] = { ParamToKeyMap.get(p).flatMap { key => key match { - case k: GlobalParamKey[T] => getGlobalParam(k) + case k: GlobalKey[T] => + assert(!k.isServiceParam, s"${k.name} is a Service Param. getServiceParam should be used.") + getGlobalParam(k) case _ => None } } @@ -49,7 +142,9 @@ object GlobalParams { def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => getGlobalServiceParam(k) + case k: GlobalKey[T] => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + getGlobalServiceParam(k) case _ => None } } @@ -58,9 +153,11 @@ object GlobalParams { def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => + case k: GlobalKey[T] => getGlobalServiceParam(k) match { - case Some(Left(value)) => Some(value) + case Some(Left(value)) => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + Some(value) case _ => None } case _ => None @@ -71,9 +168,11 @@ object GlobalParams { def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => + case k: GlobalKey[T] => getGlobalServiceParam(k) match { - case Some(Right(colName)) => Some(colName) + case Some(Right(colName)) => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + Some(colName) case _ => None } case _ => None @@ -81,25 +180,17 @@ object GlobalParams { } } - def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { + assert(!key.isServiceParam, s"${key.name} is a Service Param. registerServiceParam should be used.") ParamToKeyMap(p) = key } - def registerServiceParam[T](sp: ServiceParam[T], key: GlobalServiceParamKey[T]): Unit = { + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. registerParam should be used.") ParamToKeyMap(sp) = key } - - def setDeploymentName(deploymentName: String): Unit = { - setGlobalServiceParamKey(OpenAIDeploymentNameKey, deploymentName) - } - - def setDeploymentNameCol(colName: String): Unit = { - setGlobalServiceParamKeyCol(OpenAIDeploymentNameKey, colName) - } } - - trait HasGlobalParams extends HasServiceParams { def getGlobalParam[T](p: Param[T]): T = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index dd1b61b63e..90de15260e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,6 +51,7 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } + trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index c6714decaa..792ae04957 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -5,11 +5,44 @@ package com.microsoft.azure.synapse.ml.services import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.core.test.base.Flaky -import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt} -import org.apache.spark.ml.util.MLReadable +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.scalactic.Equality +import spray.json.DefaultJsonProtocol.IntJsonFormat + +trait TestGlobalParamsTrait extends HasGlobalParams { + val testParam: Param[Double] = new Param[Double]( + this, "TestParam", "Test Param") + + GlobalParams.registerParam(testParam, TestParamKey) + + def getTestParam: Double = getGlobalParam(testParam) + + def setTestParam(v: Double): this.type = set(testParam, v) + + val testServiceParam = new ServiceParam[Int]( + this, "testServiceParam", "Test Service Param", isRequired = false) + + GlobalParams.registerServiceParam(testServiceParam, TestServiceParamKey) + + def getTestServiceParam: Int = getGlobalServiceParamScalar(testServiceParam) + + def setTestServiceParam(v: Int): this.type = setScalarParam(testServiceParam, v) + + def getTestServiceParamCol: String = getVectorParam(testServiceParam) + + def setTestServiceParamCol(v: String): this.type = setVectorParam(testServiceParam, v) +} + +class TestGlobalParams extends TestGlobalParamsTrait { + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override val uid: String = Identifiable.randomUID("TestGlobalParams") +} class GlobalParamSuite extends Flaky with OpenAIAPIKey { @@ -21,7 +54,35 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { super.beforeAll() } - GlobalParams.setDeploymentName(deploymentName) + GlobalParams.setGlobalParam(TestParamKey, 12.5) + GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) + + val testGlobalParams = new TestGlobalParams() + + test("Basic Usage") { + assert(testGlobalParams.getTestParam == 12.5) + assert(testGlobalParams.getTestServiceParam == 1) + } + + test("Test Changing Value") { + assert(testGlobalParams.getTestParam == 12.5) + GlobalParams.setGlobalParam("TestParam", 19.853) + assert(testGlobalParams.getTestParam == 19.853) + assert(testGlobalParams.getTestServiceParam == 1) + GlobalParams.setGlobalServiceParam("TestServiceParam", 4) + assert(testGlobalParams.getTestServiceParam == 4) + } + + test("Using wrong setters and getters") { + assertThrows[AssertionError] { + GlobalParams.setGlobalServiceParam("TestParam", 8.8888) + } + assertThrows[AssertionError] { + GlobalParams.getParam(testGlobalParams.testServiceParam) + } + } + + GlobalParams.setGlobalServiceParam("OpenAIDeploymentName", deploymentName) lazy val prompt: OpenAIPrompt = new OpenAIPrompt() .setSubscriptionKey(openAIAPIKey) @@ -36,7 +97,7 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { (null, "none") //scalastyle:ignore null ).toDF("text", "category") - test("Basic Usage") { + test("OpenAIPrompt w Globals") { val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -48,3 +109,4 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { assert(nonNullCount == 3) } } + From 7fb45dbc75661eee2951ebb23de62d25e3535f97 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Mon, 18 Nov 2024 11:53:16 -0500 Subject: [PATCH 4/4] Use initialize objects with GlobalParams and get rid of boxedClass logic --- .../synapse/ml/services/GlobalParam.scala | 60 ++++--------------- .../synapse/ml/services/openai/OpenAI.scala | 2 + .../ml/services/GlobalParamSuite.scala | 10 +++- 3 files changed, 19 insertions(+), 53 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala index 53d21a8ebf..9d9c220942 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -5,63 +5,35 @@ import org.apache.spark.ml.param.Param import org.apache.spark.sql.Row import scala.collection.mutable -import scala.reflect.ClassTag +import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils trait GlobalKey[T] { val name: String val isServiceParam: Boolean } -case object OpenAIDeploymentNameKey extends GlobalKey[String] - {val name: String = "OpenAIDeploymentName"; val isServiceParam = true} -case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} -case object TestServiceParamKey extends GlobalKey[Int] - {val name: String = "TestServiceParam"; val isServiceParam = true} - object GlobalParams { private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty - private def boxedClass(c: Class[_]): Class[_] = { - if (!c.isPrimitive) c - c match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Float.TYPE => classOf[java.lang.Float] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => c // Fallback for any other primitive types + private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = { + val strToKeyMap = mutable.Map[String, GlobalKey[_]]() + JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] => + strToKeyMap += (key.name -> key) } + strToKeyMap } - private val StringtoKeyMap: Map[String, GlobalKey[_]] = Map( - "OpenAIDeploymentName" -> OpenAIDeploymentNameKey, - "TestParam" -> TestParamKey, - "TestServiceParam" -> TestServiceParamKey, - ) - private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { StringtoKeyMap.get(keyName) } - def setGlobalParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.") - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) GlobalParams(key) = value } - def setGlobalParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) + def setGlobalParam[T](keyName: String, value: T): Unit = { val key = findGlobalKeyByName(keyName) key match { case Some(k) => @@ -71,24 +43,12 @@ object GlobalParams { } } - def setGlobalServiceParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + def setGlobalServiceParam[T](key: GlobalKey[T], value: T): Unit = { assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) GlobalParams(key) = Left(value) } - def setGlobalServiceParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) + def setGlobalServiceParam[T](keyName: String, value: T): Unit = { val key = findGlobalKeyByName(keyName) key match { case Some(k) => diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 90de15260e..e89bb2b735 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,6 +51,8 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } +case object OpenAIDeploymentNameKey extends GlobalKey[String] +{val name: String = "OpenAIDeploymentName"; val isServiceParam = true} trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index 792ae04957..f8f331d25c 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -44,6 +44,11 @@ class TestGlobalParams extends TestGlobalParamsTrait { override val uid: String = Identifiable.randomUID("TestGlobalParams") } + +case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} +case object TestServiceParamKey extends GlobalKey[Int] +{val name: String = "TestServiceParam"; val isServiceParam = true} + class GlobalParamSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ @@ -54,12 +59,11 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { super.beforeAll() } - GlobalParams.setGlobalParam(TestParamKey, 12.5) - GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) - val testGlobalParams = new TestGlobalParams() test("Basic Usage") { + GlobalParams.setGlobalParam(TestParamKey, 12.5) + GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) assert(testGlobalParams.getTestParam == 12.5) assert(testGlobalParams.getTestServiceParam == 1) }