From 5e0c3bf45f85473f4528cea65856f37671c751ab Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Sat, 2 Nov 2024 22:43:58 -0500 Subject: [PATCH 1/8] GlobalParamObject implementation with higher-order getters. DeploymentName added to GlobalParam set --- .../ml/services/openai/GlobalParam.scala | 161 ++++++++++++++++++ .../synapse/ml/services/openai/OpenAI.scala | 6 +- .../openai/OpenAIChatCompletion.scala | 2 +- .../ml/services/openai/OpenAICompletion.scala | 2 +- .../ml/services/openai/OpenAIEmbedding.scala | 2 +- .../ml/services/openai/GlobalParamSuite.scala | 64 +++++++ 6 files changed, 232 insertions(+), 5 deletions(-) create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala new file mode 100644 index 0000000000..bdac25dd61 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala @@ -0,0 +1,161 @@ +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services.HasServiceParams +import org.apache.spark.ml.param.Param +import org.apache.spark.sql.Row + +import scala.collection.mutable + +sealed trait GlobalParamKey[T] +case object OpenAIDeploymentNameKey extends GlobalParamKey[String] + +object GlobalParamObject { + private val ParamToKeyMap: mutable.Map[Any, GlobalParamKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty + + def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { + GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + } + + def getGlobalParamKey[T](key: GlobalParamKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParamKey(k) + case _ => None + } + } + } + + def getServiceParam[T](sp: ServiceParam[T]): Option[T] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParamKey(k) + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(p) = key + } + + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(sp) = key + } +} + + + +trait HasGlobalParams extends HasServiceParams { + + def getGlobalParam[T](p: Param[T]): T = { + try { + this.getOrDefault(p) + } + catch { + case e: Exception => + GlobalParamObject.getParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalParam[T](name: String): T = { + val param = this.getParam(name).asInstanceOf[Param[T]] + try { + this.getOrDefault(param) + } + catch { + case e: Exception => + GlobalParamObject.getParam(param) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](p: ServiceParam[T]): T= { //TODO: should return a string + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p)match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](name: String): T = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](name: String): T = { //TODO: should return a string + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { + //TODO: Confirm with Mark what this colName business is about + try { + get(p).orElse(getDefault(p)).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + match { + case some @ Some(_) => some + case None => GlobalParamObject.getServiceParam(p) + } + } + catch { + case e: Exception => + GlobalParamObject.getServiceParam(p) match { + case Some(v) => Some(v) + case None => throw e + } + } + } + + + protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = + getGlobalServiceParamValueOpt(row, p).get +} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index b57f4d65da..883b2b7c61 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,12 +51,14 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } -trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { +trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) - def getDeploymentName: String = getScalarParam(deploymentName) + GlobalParamObject.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) + + def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName) def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index 703fc7f471..e667cc0625 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -35,7 +35,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/chat/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala index 953138bc36..5fbbb1243b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala @@ -34,7 +34,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index 58f5c857d6..8c9ee199f3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -59,7 +59,7 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings" + s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/embeddings" } private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = { diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala new file mode 100644 index 0000000000..c3612ddf46 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala @@ -0,0 +1,64 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.Secrets.getAccessToken +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import org.apache.spark.ml.util.MLReadable +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Row} +import org.scalactic.Equality + +class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { + + import spark.implicits._ + + override def beforeAll(): Unit = { + val aadToken = getAccessToken("https://cognitiveservices.azure.com/") + println(s"Triggering token creation early ${aadToken.length}") + super.beforeAll() + } + + GlobalParamObject.setGlobalParamKey(OpenAIDeploymentNameKey, deploymentName) + + lazy val prompt: OpenAIPrompt = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + + lazy val df: DataFrame = Seq( + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + (null, "none") //scalastyle:ignore null + ).toDF("text", "category") + + test("Basic Usage") { + val nonNullCount = prompt + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessing("csv") + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + + assert(nonNullCount == 3) + } + + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { + super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) + } + + override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val testPrompt = prompt + .setPromptTemplate("{text} rhymes with ") + + Seq(new TestObject(testPrompt, df)) + } + + override def reader: MLReadable[_] = OpenAIPrompt + +} From 7451422098237906e7f8a47d8f932e5ff20551ea Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 7 Nov 2024 11:44:13 -0500 Subject: [PATCH 2/8] Enable ServiceParams as part of GlobalParam --- .../synapse/ml/services/GlobalParam.scala | 213 ++++++++++++++++++ .../ml/services/openai/GlobalParam.scala | 161 ------------- .../synapse/ml/services/openai/OpenAI.scala | 2 +- .../{openai => }/GlobalParamSuite.scala | 24 +- 4 files changed, 219 insertions(+), 181 deletions(-) create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala delete mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala rename cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/{openai => }/GlobalParamSuite.scala (64%) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala new file mode 100644 index 0000000000..b1546985d0 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -0,0 +1,213 @@ +package com.microsoft.azure.synapse.ml.services + +import com.microsoft.azure.synapse.ml.param.ServiceParam +import org.apache.spark.ml.param.Param +import org.apache.spark.sql.Row + +import scala.collection.mutable + +sealed trait GlobalKey[T] +sealed trait GlobalParamKey[T] extends GlobalKey[T] +sealed trait GlobalServiceParamKey[T] extends GlobalKey[Either[T, String]] //left is Scalar, right is Vector + +case object OpenAIDeploymentNameKey extends GlobalServiceParamKey[String] + +object GlobalParams { + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty + private val GlobalServiceParams: mutable.Map[GlobalServiceParamKey[Any], Either[Any, String]] = mutable.Map.empty + + private def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { + GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + } + + private def setGlobalServiceParamKey[T](key: GlobalServiceParamKey[T], value: T): Unit = { + GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Left(value) + } + + private def setGlobalServiceParamKeyCol[T](key: GlobalServiceParamKey[T], value: String): Unit = { + GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Right(value) + } + + private def getGlobalParam[T](key: GlobalParamKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + } + + private def getGlobalServiceParam[T](key: GlobalServiceParamKey[T]): Option[Either[T, String]] = { + GlobalServiceParams.get(key.asInstanceOf[GlobalServiceParamKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalParamKey[T] => getGlobalParam(k) + case _ => None + } + } + } + + def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => getGlobalServiceParam(k) + case _ => None + } + } + } + + def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => + getGlobalServiceParam(k) match { + case Some(Left(value)) => Some(value) + case _ => None + } + case _ => None + } + } + } + + def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = { + ParamToKeyMap.get(sp).flatMap { key => + key match { + case k: GlobalServiceParamKey[T] => + getGlobalServiceParam(k) match { + case Some(Right(colName)) => Some(colName) + case _ => None + } + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + ParamToKeyMap(p) = key + } + + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalServiceParamKey[T]): Unit = { + ParamToKeyMap(sp) = key + } + + def setDeploymentName(deploymentName: String): Unit = { + setGlobalServiceParamKey(OpenAIDeploymentNameKey, deploymentName) + } + + def setDeploymentNameCol(colName: String): Unit = { + setGlobalServiceParamKeyCol(OpenAIDeploymentNameKey, colName) + } +} + + + +trait HasGlobalParams extends HasServiceParams { + + def getGlobalParam[T](p: Param[T]): T = { + try { + this.getOrDefault(p) + } + catch { + case e: Exception => + GlobalParams.getParam(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalParam[T](name: String): T = { + val param = this.getParam(name).asInstanceOf[Param[T]] + try { + this.getOrDefault(param) + } + catch { + case e: Exception => + GlobalParams.getParam(param) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { + try { + this.getScalarParam(p) + } + catch { + case e: Exception => + GlobalParams.getServiceParamScalar(p) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](p: ServiceParam[T]): String= { + try { + this.getVectorParam(p) + } + catch { + case e: Exception => + GlobalParams.getServiceParamVector(p)match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamScalar[T](name: String): T = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getScalarParam(serviceParam) + } + catch { + case e: Exception => + GlobalParams.getServiceParamScalar(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + def getGlobalServiceParamVector[T](name: String): String = { + val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] + try { + this.getVectorParam(serviceParam) + } + catch { + case e: Exception => + GlobalParams.getServiceParamVector(serviceParam) match { + case Some(v) => v + case None => throw e + } + } + } + + protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { + val globalParam: Option[T] = GlobalParams.getServiceParam(p).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + try { + get(p).orElse(getDefault(p)).flatMap { + case Right(colName) => Option(row.getAs[T](colName)) + case Left(value) => Some(value) + } + match { + case some @ Some(_) => some + case None => globalParam + } + } + catch { + case e: Exception => + globalParam match { + case Some(v) => Some(v) + case None => throw e + } + } + } + + + protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = + getGlobalServiceParamValueOpt(row, p).get +} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala deleted file mode 100644 index bdac25dd61..0000000000 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParam.scala +++ /dev/null @@ -1,161 +0,0 @@ -package com.microsoft.azure.synapse.ml.services.openai - -import com.microsoft.azure.synapse.ml.param.ServiceParam -import com.microsoft.azure.synapse.ml.services.HasServiceParams -import org.apache.spark.ml.param.Param -import org.apache.spark.sql.Row - -import scala.collection.mutable - -sealed trait GlobalParamKey[T] -case object OpenAIDeploymentNameKey extends GlobalParamKey[String] - -object GlobalParamObject { - private val ParamToKeyMap: mutable.Map[Any, GlobalParamKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty - - def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { - GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value - } - - def getGlobalParamKey[T](key: GlobalParamKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) - } - - def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalParamKey[T] => getGlobalParamKey(k) - case _ => None - } - } - } - - def getServiceParam[T](sp: ServiceParam[T]): Option[T] = { - ParamToKeyMap.get(sp).flatMap { key => - key match { - case k: GlobalParamKey[T] => getGlobalParamKey(k) - case _ => None - } - } - } - - def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { - ParamToKeyMap(p) = key - } - - def registerServiceParam[T](sp: ServiceParam[T], key: GlobalParamKey[T]): Unit = { - ParamToKeyMap(sp) = key - } -} - - - -trait HasGlobalParams extends HasServiceParams { - - def getGlobalParam[T](p: Param[T]): T = { - try { - this.getOrDefault(p) - } - catch { - case e: Exception => - GlobalParamObject.getParam(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalParam[T](name: String): T = { - val param = this.getParam(name).asInstanceOf[Param[T]] - try { - this.getOrDefault(param) - } - catch { - case e: Exception => - GlobalParamObject.getParam(param) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { - try { - this.getScalarParam(p) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](p: ServiceParam[T]): T= { //TODO: should return a string - try { - this.getScalarParam(p) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p)match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](name: String): T = { - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getScalarParam(serviceParam) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](name: String): T = { //TODO: should return a string - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getScalarParam(serviceParam) - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { - //TODO: Confirm with Mark what this colName business is about - try { - get(p).orElse(getDefault(p)).flatMap { - case Right(colName) => Option(row.getAs[T](colName)) - case Left(value) => Some(value) - } - match { - case some @ Some(_) => some - case None => GlobalParamObject.getServiceParam(p) - } - } - catch { - case e: Exception => - GlobalParamObject.getServiceParam(p) match { - case Some(v) => Some(v) - case None => throw e - } - } - } - - - protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = - getGlobalServiceParamValueOpt(row, p).get -} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index f91b9b418b..dd1b61b63e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -56,7 +56,7 @@ trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) - GlobalParamObject.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) + GlobalParams.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala similarity index 64% rename from cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala rename to cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index c3612ddf46..c6714decaa 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -1,17 +1,17 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.azure.synapse.ml.services.openai +package com.microsoft.azure.synapse.ml.services import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.core.test.base.Flaky import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt} import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame import org.scalactic.Equality -class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { +class GlobalParamSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ @@ -21,7 +21,7 @@ class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKe super.beforeAll() } - GlobalParamObject.setGlobalParamKey(OpenAIDeploymentNameKey, deploymentName) + GlobalParams.setDeploymentName(deploymentName) lazy val prompt: OpenAIPrompt = new OpenAIPrompt() .setSubscriptionKey(openAIAPIKey) @@ -47,18 +47,4 @@ class GlobalParamSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKe assert(nonNullCount == 3) } - - override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { - super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) - } - - override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { - val testPrompt = prompt - .setPromptTemplate("{text} rhymes with ") - - Seq(new TestObject(testPrompt, df)) - } - - override def reader: MLReadable[_] = OpenAIPrompt - } From efeb9d0ea3595ebfb3d9a7b395779bfd29a28238 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Mon, 11 Nov 2024 11:07:31 -0500 Subject: [PATCH 3/8] Only one GlobalKey type and only one GlobalParams dictionary. Added asserts to type check Params vs Service Params. Added more tests. Tests pass! --- .../synapse/ml/services/GlobalParam.scala | 159 ++++++++++++++---- .../synapse/ml/services/openai/OpenAI.scala | 1 + .../ml/services/GlobalParamSuite.scala | 72 +++++++- 3 files changed, 193 insertions(+), 39 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala index b1546985d0..53d21a8ebf 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -5,42 +5,135 @@ import org.apache.spark.ml.param.Param import org.apache.spark.sql.Row import scala.collection.mutable +import scala.reflect.ClassTag -sealed trait GlobalKey[T] -sealed trait GlobalParamKey[T] extends GlobalKey[T] -sealed trait GlobalServiceParamKey[T] extends GlobalKey[Either[T, String]] //left is Scalar, right is Vector +trait GlobalKey[T] { + val name: String + val isServiceParam: Boolean +} -case object OpenAIDeploymentNameKey extends GlobalServiceParamKey[String] +case object OpenAIDeploymentNameKey extends GlobalKey[String] + {val name: String = "OpenAIDeploymentName"; val isServiceParam = true} +case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} +case object TestServiceParamKey extends GlobalKey[Int] + {val name: String = "TestServiceParam"; val isServiceParam = true} object GlobalParams { private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalParamKey[Any], Any] = mutable.Map.empty - private val GlobalServiceParams: mutable.Map[GlobalServiceParamKey[Any], Either[Any, String]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + + private def boxedClass(c: Class[_]): Class[_] = { + if (!c.isPrimitive) c + c match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Float.TYPE => classOf[java.lang.Float] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => c // Fallback for any other primitive types + } + } + + private val StringtoKeyMap: Map[String, GlobalKey[_]] = Map( + "OpenAIDeploymentName" -> OpenAIDeploymentNameKey, + "TestParam" -> TestParamKey, + "TestServiceParam" -> TestServiceParamKey, + ) + + private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { + StringtoKeyMap.get(keyName) + } + + def setGlobalParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.") + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + GlobalParams(key) = value + } - private def setGlobalParamKey[T](key: GlobalParamKey[T], value: T): Unit = { - GlobalParams(key.asInstanceOf[GlobalParamKey[Any]]) = value + def setGlobalParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(!k.isServiceParam, s"${k.name} is a Service Param. setGlobalServiceParamKey should be used.") + GlobalParams(k) = value + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } } - private def setGlobalServiceParamKey[T](key: GlobalServiceParamKey[T], value: T): Unit = { - GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Left(value) + def setGlobalServiceParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + GlobalParams(key) = Left(value) } - private def setGlobalServiceParamKeyCol[T](key: GlobalServiceParamKey[T], value: String): Unit = { - GlobalServiceParams(key.asInstanceOf[GlobalServiceParamKey[Any]]) = Right(value) + def setGlobalServiceParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { + val expectedClass = boxedClass(ct.runtimeClass) + val actualClass = value.getClass + assert( + expectedClass.isAssignableFrom(actualClass), + s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" + ) + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(k) = Left(value) + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } } - private def getGlobalParam[T](key: GlobalParamKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalParamKey[Any]]).map(_.asInstanceOf[T]) + def setGlobalServiceParamCol[T](key: GlobalKey[T], value: String): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(key) = Right(value) } - private def getGlobalServiceParam[T](key: GlobalServiceParamKey[T]): Option[Either[T, String]] = { - GlobalServiceParams.get(key.asInstanceOf[GlobalServiceParamKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + def setGlobalServiceParamCol[T](keyName: String, value: String): Unit = { + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") + GlobalParams(k) = Right(value) + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } + } + + private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + } + + private def getGlobalServiceParam[T](key: GlobalKey[T]): Option[Either[T, String]] = { + val value = GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[Either[T, String]]) + value match { + case some @ Some(v) => + assert(v.isInstanceOf[Either[T, String]], + "getGlobalServiceParam used to fetch a normal Param") + value + case None => None + } } def getParam[T](p: Param[T]): Option[T] = { ParamToKeyMap.get(p).flatMap { key => key match { - case k: GlobalParamKey[T] => getGlobalParam(k) + case k: GlobalKey[T] => + assert(!k.isServiceParam, s"${k.name} is a Service Param. getServiceParam should be used.") + getGlobalParam(k) case _ => None } } @@ -49,7 +142,9 @@ object GlobalParams { def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => getGlobalServiceParam(k) + case k: GlobalKey[T] => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + getGlobalServiceParam(k) case _ => None } } @@ -58,9 +153,11 @@ object GlobalParams { def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => + case k: GlobalKey[T] => getGlobalServiceParam(k) match { - case Some(Left(value)) => Some(value) + case Some(Left(value)) => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + Some(value) case _ => None } case _ => None @@ -71,9 +168,11 @@ object GlobalParams { def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = { ParamToKeyMap.get(sp).flatMap { key => key match { - case k: GlobalServiceParamKey[T] => + case k: GlobalKey[T] => getGlobalServiceParam(k) match { - case Some(Right(colName)) => Some(colName) + case Some(Right(colName)) => + assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") + Some(colName) case _ => None } case _ => None @@ -81,25 +180,17 @@ object GlobalParams { } } - def registerParam[T](p: Param[T], key: GlobalParamKey[T]): Unit = { + def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { + assert(!key.isServiceParam, s"${key.name} is a Service Param. registerServiceParam should be used.") ParamToKeyMap(p) = key } - def registerServiceParam[T](sp: ServiceParam[T], key: GlobalServiceParamKey[T]): Unit = { + def registerServiceParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = { + assert(key.isServiceParam, s"${key.name} is a Param. registerParam should be used.") ParamToKeyMap(sp) = key } - - def setDeploymentName(deploymentName: String): Unit = { - setGlobalServiceParamKey(OpenAIDeploymentNameKey, deploymentName) - } - - def setDeploymentNameCol(colName: String): Unit = { - setGlobalServiceParamKeyCol(OpenAIDeploymentNameKey, colName) - } } - - trait HasGlobalParams extends HasServiceParams { def getGlobalParam[T](p: Param[T]): T = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index dd1b61b63e..90de15260e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,6 +51,7 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } + trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index c6714decaa..792ae04957 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -5,11 +5,44 @@ package com.microsoft.azure.synapse.ml.services import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.core.test.base.Flaky -import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt} -import org.apache.spark.ml.util.MLReadable +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.scalactic.Equality +import spray.json.DefaultJsonProtocol.IntJsonFormat + +trait TestGlobalParamsTrait extends HasGlobalParams { + val testParam: Param[Double] = new Param[Double]( + this, "TestParam", "Test Param") + + GlobalParams.registerParam(testParam, TestParamKey) + + def getTestParam: Double = getGlobalParam(testParam) + + def setTestParam(v: Double): this.type = set(testParam, v) + + val testServiceParam = new ServiceParam[Int]( + this, "testServiceParam", "Test Service Param", isRequired = false) + + GlobalParams.registerServiceParam(testServiceParam, TestServiceParamKey) + + def getTestServiceParam: Int = getGlobalServiceParamScalar(testServiceParam) + + def setTestServiceParam(v: Int): this.type = setScalarParam(testServiceParam, v) + + def getTestServiceParamCol: String = getVectorParam(testServiceParam) + + def setTestServiceParamCol(v: String): this.type = setVectorParam(testServiceParam, v) +} + +class TestGlobalParams extends TestGlobalParamsTrait { + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override val uid: String = Identifiable.randomUID("TestGlobalParams") +} class GlobalParamSuite extends Flaky with OpenAIAPIKey { @@ -21,7 +54,35 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { super.beforeAll() } - GlobalParams.setDeploymentName(deploymentName) + GlobalParams.setGlobalParam(TestParamKey, 12.5) + GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) + + val testGlobalParams = new TestGlobalParams() + + test("Basic Usage") { + assert(testGlobalParams.getTestParam == 12.5) + assert(testGlobalParams.getTestServiceParam == 1) + } + + test("Test Changing Value") { + assert(testGlobalParams.getTestParam == 12.5) + GlobalParams.setGlobalParam("TestParam", 19.853) + assert(testGlobalParams.getTestParam == 19.853) + assert(testGlobalParams.getTestServiceParam == 1) + GlobalParams.setGlobalServiceParam("TestServiceParam", 4) + assert(testGlobalParams.getTestServiceParam == 4) + } + + test("Using wrong setters and getters") { + assertThrows[AssertionError] { + GlobalParams.setGlobalServiceParam("TestParam", 8.8888) + } + assertThrows[AssertionError] { + GlobalParams.getParam(testGlobalParams.testServiceParam) + } + } + + GlobalParams.setGlobalServiceParam("OpenAIDeploymentName", deploymentName) lazy val prompt: OpenAIPrompt = new OpenAIPrompt() .setSubscriptionKey(openAIAPIKey) @@ -36,7 +97,7 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { (null, "none") //scalastyle:ignore null ).toDF("text", "category") - test("Basic Usage") { + test("OpenAIPrompt w Globals") { val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -48,3 +109,4 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { assert(nonNullCount == 3) } } + From 7fb45dbc75661eee2951ebb23de62d25e3535f97 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Mon, 18 Nov 2024 11:53:16 -0500 Subject: [PATCH 4/8] Use initialize objects with GlobalParams and get rid of boxedClass logic --- .../synapse/ml/services/GlobalParam.scala | 60 ++++--------------- .../synapse/ml/services/openai/OpenAI.scala | 2 + .../ml/services/GlobalParamSuite.scala | 10 +++- 3 files changed, 19 insertions(+), 53 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala index 53d21a8ebf..9d9c220942 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala @@ -5,63 +5,35 @@ import org.apache.spark.ml.param.Param import org.apache.spark.sql.Row import scala.collection.mutable -import scala.reflect.ClassTag +import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils trait GlobalKey[T] { val name: String val isServiceParam: Boolean } -case object OpenAIDeploymentNameKey extends GlobalKey[String] - {val name: String = "OpenAIDeploymentName"; val isServiceParam = true} -case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} -case object TestServiceParamKey extends GlobalKey[Int] - {val name: String = "TestServiceParam"; val isServiceParam = true} - object GlobalParams { private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty - private def boxedClass(c: Class[_]): Class[_] = { - if (!c.isPrimitive) c - c match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Float.TYPE => classOf[java.lang.Float] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => c // Fallback for any other primitive types + private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = { + val strToKeyMap = mutable.Map[String, GlobalKey[_]]() + JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] => + strToKeyMap += (key.name -> key) } + strToKeyMap } - private val StringtoKeyMap: Map[String, GlobalKey[_]] = Map( - "OpenAIDeploymentName" -> OpenAIDeploymentNameKey, - "TestParam" -> TestParamKey, - "TestServiceParam" -> TestServiceParamKey, - ) - private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { StringtoKeyMap.get(keyName) } - def setGlobalParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.") - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) GlobalParams(key) = value } - def setGlobalParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) + def setGlobalParam[T](keyName: String, value: T): Unit = { val key = findGlobalKeyByName(keyName) key match { case Some(k) => @@ -71,24 +43,12 @@ object GlobalParams { } } - def setGlobalServiceParam[T](key: GlobalKey[T], value: T)(implicit ct: ClassTag[T]): Unit = { + def setGlobalServiceParam[T](key: GlobalKey[T], value: T): Unit = { assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) GlobalParams(key) = Left(value) } - def setGlobalServiceParam[T](keyName: String, value: T)(implicit ct: ClassTag[T]): Unit = { - val expectedClass = boxedClass(ct.runtimeClass) - val actualClass = value.getClass - assert( - expectedClass.isAssignableFrom(actualClass), - s"Value of type ${actualClass.getName} is not compatible with expected type ${expectedClass.getName}" - ) + def setGlobalServiceParam[T](keyName: String, value: T): Unit = { val key = findGlobalKeyByName(keyName) key match { case Some(k) => diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 90de15260e..e89bb2b735 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,6 +51,8 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } +case object OpenAIDeploymentNameKey extends GlobalKey[String] +{val name: String = "OpenAIDeploymentName"; val isServiceParam = true} trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala index 792ae04957..f8f331d25c 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala @@ -44,6 +44,11 @@ class TestGlobalParams extends TestGlobalParamsTrait { override val uid: String = Identifiable.randomUID("TestGlobalParams") } + +case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} +case object TestServiceParamKey extends GlobalKey[Int] +{val name: String = "TestServiceParam"; val isServiceParam = true} + class GlobalParamSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ @@ -54,12 +59,11 @@ class GlobalParamSuite extends Flaky with OpenAIAPIKey { super.beforeAll() } - GlobalParams.setGlobalParam(TestParamKey, 12.5) - GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) - val testGlobalParams = new TestGlobalParams() test("Basic Usage") { + GlobalParams.setGlobalParam(TestParamKey, 12.5) + GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) assert(testGlobalParams.getTestParam == 12.5) assert(testGlobalParams.getTestServiceParam == 1) } From 8d0f87c47cfe0e1827bbfbde22dd23a5ea11638e Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 21 Nov 2024 13:56:05 -0500 Subject: [PATCH 5/8] Refactor GlobalParams. Move to core, and set up OpenAIDefaults as user interface. --- .../ml/services/CognitiveServiceBase.scala | 17 +- .../synapse/ml/services/GlobalParam.scala | 264 ------------------ .../synapse/ml/services/openai/OpenAI.scala | 10 +- .../openai/OpenAIChatCompletion.scala | 2 +- .../ml/services/openai/OpenAICompletion.scala | 2 +- .../ml/services/openai/OpenAIDefaults.scala | 9 + .../ml/services/openai/OpenAIEmbedding.scala | 2 +- .../ml/services/openai/OpenAIPrompt.scala | 16 +- .../ml/services/GlobalParamSuite.scala | 116 -------- .../services/openai/OpenAIDefaultsSuite.scala | 56 ++++ .../azure/synapse/ml/param/GlobalParams.scala | 79 ++++++ .../synapse/ml/param/GlobalParamsSuite.scala | 60 ++++ 12 files changed, 242 insertions(+), 391 deletions(-) delete mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala delete mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala create mode 100644 core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala create mode 100644 core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 2d123edf56..67b8f69782 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient import com.microsoft.azure.synapse.ml.io.http._ import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam} import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda} import org.apache.http.NameValuePair import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase} @@ -547,6 +547,21 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform } override def transform(dataset: Dataset[_]): DataFrame = { + // check for empty params. Fill em with OpenAIDefaults + this.params + .filter(p => !this.isSet(p) && !this.hasDefault(p)) + .foreach { p => + GlobalParams.getParam(p) match { + case Some(v) => + p match { + case serviceParam: ServiceParam[_] => + setScalarParam(serviceParam.asInstanceOf[ServiceParam[Any]], v) + case param: Param[_] => + set(param.asInstanceOf[Param[Any]], v) + } + case None => + } + } logTransform[DataFrame](getInternalTransformer(dataset.schema).transform(dataset), dataset.columns.length ) } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala deleted file mode 100644 index 9d9c220942..0000000000 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala +++ /dev/null @@ -1,264 +0,0 @@ -package com.microsoft.azure.synapse.ml.services - -import com.microsoft.azure.synapse.ml.param.ServiceParam -import org.apache.spark.ml.param.Param -import org.apache.spark.sql.Row - -import scala.collection.mutable -import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils - -trait GlobalKey[T] { - val name: String - val isServiceParam: Boolean -} - -object GlobalParams { - private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty - private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty - - private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = { - val strToKeyMap = mutable.Map[String, GlobalKey[_]]() - JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] => - strToKeyMap += (key.name -> key) - } - strToKeyMap - } - - private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { - StringtoKeyMap.get(keyName) - } - - def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { - assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.") - GlobalParams(key) = value - } - - def setGlobalParam[T](keyName: String, value: T): Unit = { - val key = findGlobalKeyByName(keyName) - key match { - case Some(k) => - assert(!k.isServiceParam, s"${k.name} is a Service Param. setGlobalServiceParamKey should be used.") - GlobalParams(k) = value - case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") - } - } - - def setGlobalServiceParam[T](key: GlobalKey[T], value: T): Unit = { - assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") - GlobalParams(key) = Left(value) - } - - def setGlobalServiceParam[T](keyName: String, value: T): Unit = { - val key = findGlobalKeyByName(keyName) - key match { - case Some(k) => - assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") - GlobalParams(k) = Left(value) - case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") - } - } - - def setGlobalServiceParamCol[T](key: GlobalKey[T], value: String): Unit = { - assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.") - GlobalParams(key) = Right(value) - } - - def setGlobalServiceParamCol[T](keyName: String, value: String): Unit = { - val key = findGlobalKeyByName(keyName) - key match { - case Some(k) => - assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.") - GlobalParams(k) = Right(value) - case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") - } - } - - private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { - GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) - } - - private def getGlobalServiceParam[T](key: GlobalKey[T]): Option[Either[T, String]] = { - val value = GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[Either[T, String]]) - value match { - case some @ Some(v) => - assert(v.isInstanceOf[Either[T, String]], - "getGlobalServiceParam used to fetch a normal Param") - value - case None => None - } - } - - def getParam[T](p: Param[T]): Option[T] = { - ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalKey[T] => - assert(!k.isServiceParam, s"${k.name} is a Service Param. getServiceParam should be used.") - getGlobalParam(k) - case _ => None - } - } - } - - def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = { - ParamToKeyMap.get(sp).flatMap { key => - key match { - case k: GlobalKey[T] => - assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") - getGlobalServiceParam(k) - case _ => None - } - } - } - - def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = { - ParamToKeyMap.get(sp).flatMap { key => - key match { - case k: GlobalKey[T] => - getGlobalServiceParam(k) match { - case Some(Left(value)) => - assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") - Some(value) - case _ => None - } - case _ => None - } - } - } - - def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = { - ParamToKeyMap.get(sp).flatMap { key => - key match { - case k: GlobalKey[T] => - getGlobalServiceParam(k) match { - case Some(Right(colName)) => - assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.") - Some(colName) - case _ => None - } - case _ => None - } - } - } - - def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { - assert(!key.isServiceParam, s"${key.name} is a Service Param. registerServiceParam should be used.") - ParamToKeyMap(p) = key - } - - def registerServiceParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = { - assert(key.isServiceParam, s"${key.name} is a Param. registerParam should be used.") - ParamToKeyMap(sp) = key - } -} - -trait HasGlobalParams extends HasServiceParams { - - def getGlobalParam[T](p: Param[T]): T = { - try { - this.getOrDefault(p) - } - catch { - case e: Exception => - GlobalParams.getParam(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalParam[T](name: String): T = { - val param = this.getParam(name).asInstanceOf[Param[T]] - try { - this.getOrDefault(param) - } - catch { - case e: Exception => - GlobalParams.getParam(param) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = { - try { - this.getScalarParam(p) - } - catch { - case e: Exception => - GlobalParams.getServiceParamScalar(p) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](p: ServiceParam[T]): String= { - try { - this.getVectorParam(p) - } - catch { - case e: Exception => - GlobalParams.getServiceParamVector(p)match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamScalar[T](name: String): T = { - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getScalarParam(serviceParam) - } - catch { - case e: Exception => - GlobalParams.getServiceParamScalar(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - def getGlobalServiceParamVector[T](name: String): String = { - val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]] - try { - this.getVectorParam(serviceParam) - } - catch { - case e: Exception => - GlobalParams.getServiceParamVector(serviceParam) match { - case Some(v) => v - case None => throw e - } - } - } - - protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = { - val globalParam: Option[T] = GlobalParams.getServiceParam(p).flatMap { - case Right(colName) => Option(row.getAs[T](colName)) - case Left(value) => Some(value) - } - try { - get(p).orElse(getDefault(p)).flatMap { - case Right(colName) => Option(row.getAs[T](colName)) - case Left(value) => Some(value) - } - match { - case some @ Some(_) => some - case None => globalParam - } - } - catch { - case e: Exception => - globalParam match { - case Some(v) => Some(v) - case None => throw e - } - } - } - - - protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T = - getGlobalServiceParamValueOpt(row, p).get -} \ No newline at end of file diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index e89bb2b735..f64c6f91e1 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.codegen.GenerationUtils import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{Param, Params} @@ -52,16 +52,16 @@ trait HasMessagesInput extends Params { } case object OpenAIDeploymentNameKey extends GlobalKey[String] -{val name: String = "OpenAIDeploymentName"; val isServiceParam = true} +{val name: String = "OpenAIDeploymentName"} -trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion { +trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) - GlobalParams.registerServiceParam(deploymentName, OpenAIDeploymentNameKey) + GlobalParams.registerParam(deploymentName, OpenAIDeploymentNameKey) - def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName) + def getDeploymentName: String = getScalarParam(deploymentName) def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index d7a81e1f8b..379d797766 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -34,7 +34,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/chat/completions" + s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala index 73d97cf7b0..219bc34d87 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala @@ -34,7 +34,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/completions" + s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions" } override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala new file mode 100644 index 0000000000..a50e7fa187 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -0,0 +1,9 @@ +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.param.GlobalParams + +object OpenAIDefaults { + def setDeploymentName(v: String): Unit = { + GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, v) + } +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index cdb6e6d070..342254f8fb 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -60,7 +60,7 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) } override protected def prepareUrlRoot: Row => String = { row => - s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/embeddings" + s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings" } private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index a43f3ffe3a..150f710966 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.StringStringMapParam +import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam, StringStringMapParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} @@ -124,7 +124,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer override def transform(dataset: Dataset[_]): DataFrame = { import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._ - + this.params + .filter(p => !this.isSet(p) && !this.hasDefault(p)) + .foreach { p => + GlobalParams.getParam(p) match { + case Some(v) => + p match { + case serviceParam: ServiceParam[_] => setScalarParam(serviceParam.asInstanceOf[ServiceParam[Any]], v) + case param: Param[_] => + set(param.asInstanceOf[Param[Any]], v) + } + case None => + } + } logTransform[DataFrame]({ val df = dataset.toDF diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala deleted file mode 100644 index f8f331d25c..0000000000 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.services - -import com.microsoft.azure.synapse.ml.Secrets.getAccessToken -import com.microsoft.azure.synapse.ml.core.test.base.Flaky -import com.microsoft.azure.synapse.ml.param.ServiceParam -import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt} -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame -import spray.json.DefaultJsonProtocol.IntJsonFormat - -trait TestGlobalParamsTrait extends HasGlobalParams { - val testParam: Param[Double] = new Param[Double]( - this, "TestParam", "Test Param") - - GlobalParams.registerParam(testParam, TestParamKey) - - def getTestParam: Double = getGlobalParam(testParam) - - def setTestParam(v: Double): this.type = set(testParam, v) - - val testServiceParam = new ServiceParam[Int]( - this, "testServiceParam", "Test Service Param", isRequired = false) - - GlobalParams.registerServiceParam(testServiceParam, TestServiceParamKey) - - def getTestServiceParam: Int = getGlobalServiceParamScalar(testServiceParam) - - def setTestServiceParam(v: Int): this.type = setScalarParam(testServiceParam, v) - - def getTestServiceParamCol: String = getVectorParam(testServiceParam) - - def setTestServiceParamCol(v: String): this.type = setVectorParam(testServiceParam, v) -} - -class TestGlobalParams extends TestGlobalParamsTrait { - - override def copy(extra: ParamMap): Transformer = defaultCopy(extra) - - override val uid: String = Identifiable.randomUID("TestGlobalParams") -} - - -case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} -case object TestServiceParamKey extends GlobalKey[Int] -{val name: String = "TestServiceParam"; val isServiceParam = true} - -class GlobalParamSuite extends Flaky with OpenAIAPIKey { - - import spark.implicits._ - - override def beforeAll(): Unit = { - val aadToken = getAccessToken("https://cognitiveservices.azure.com/") - println(s"Triggering token creation early ${aadToken.length}") - super.beforeAll() - } - - val testGlobalParams = new TestGlobalParams() - - test("Basic Usage") { - GlobalParams.setGlobalParam(TestParamKey, 12.5) - GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1) - assert(testGlobalParams.getTestParam == 12.5) - assert(testGlobalParams.getTestServiceParam == 1) - } - - test("Test Changing Value") { - assert(testGlobalParams.getTestParam == 12.5) - GlobalParams.setGlobalParam("TestParam", 19.853) - assert(testGlobalParams.getTestParam == 19.853) - assert(testGlobalParams.getTestServiceParam == 1) - GlobalParams.setGlobalServiceParam("TestServiceParam", 4) - assert(testGlobalParams.getTestServiceParam == 4) - } - - test("Using wrong setters and getters") { - assertThrows[AssertionError] { - GlobalParams.setGlobalServiceParam("TestParam", 8.8888) - } - assertThrows[AssertionError] { - GlobalParams.getParam(testGlobalParams.testServiceParam) - } - } - - GlobalParams.setGlobalServiceParam("OpenAIDeploymentName", deploymentName) - - lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) - - lazy val df: DataFrame = Seq( - ("apple", "fruits"), - ("mercedes", "cars"), - ("cake", "dishes"), - (null, "none") //scalastyle:ignore null - ).toDF("text", "category") - - test("OpenAIPrompt w Globals") { - val nonNullCount = prompt - .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") - .setPostProcessing("csv") - .transform(df) - .select("outParsed") - .collect() - .count(r => Option(r.getSeq[String](0)).isDefined) - - assert(nonNullCount == 3) - } -} - diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala new file mode 100644 index 0000000000..7ed1c064c2 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -0,0 +1,56 @@ +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import org.apache.spark.sql.{DataFrame, Row} + +class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { + + import spark.implicits._ + + OpenAIDefaults.setDeploymentName(deploymentName) + + def promptCompletion: OpenAICompletion = new OpenAICompletion() + .setCustomServiceName(openAIServiceName) + .setMaxTokens(200) + .setOutputCol("out") + .setSubscriptionKey(openAIAPIKey) + .setPromptCol("prompt") + + lazy val promptDF: DataFrame = Seq( + "Once upon a time", + "Best programming language award goes to", + "SynapseML is " + ).toDF("prompt") + + test("Completion w Globals") { + val fromRow = CompletionResponse.makeFromRowConverter + promptCompletion.transform(promptDF).collect().foreach(r => + fromRow(r.getAs[Row]("out")).choices.foreach(c => + assert(c.text.length > 10))) + } + + lazy val prompt: OpenAIPrompt = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + + lazy val df: DataFrame = Seq( + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + (null, "none") //scalastyle:ignore null + ).toDF("text", "category") + + test("OpenAIPrompt w Globals") { + val nonNullCount = prompt + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessing("csv") + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + + assert(nonNullCount == 3) + } +} diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala new file mode 100644 index 0000000000..42b3173700 --- /dev/null +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -0,0 +1,79 @@ +package com.microsoft.azure.synapse.ml.param + +import org.apache.spark.ml.param.{Param, Params} +import org.apache.spark.sql.Row + +import scala.collection.mutable +import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils + +trait GlobalKey[T] { + val name: String +} + +object GlobalParams { + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + + private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = { + val strToKeyMap = mutable.Map[String, GlobalKey[_]]() + JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] => + strToKeyMap += (key.name -> key) + } + strToKeyMap + } + + private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { + StringtoKeyMap.get(keyName) + } + + def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { + GlobalParams(key) = value + } + + def setGlobalParam[T](keyName: String, value: T): Unit = { + val key = findGlobalKeyByName(keyName) + key match { + case Some(k) => + GlobalParams(k) = value + case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") + } + } + + private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalKey[T] => + getGlobalParam(k) + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { + ParamToKeyMap(p) = key + } + + def registerParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = { + ParamToKeyMap(sp) = key + } +} + +trait HasGlobalParams extends Params{ + + def getGlobalParam[T](p: Param[T]): T = { + try { + this.getOrDefault(p) + } + catch { + case e: Exception => + GlobalParams.getParam(p) match { + case Some(v) => v + case None => throw e + } + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala new file mode 100644 index 0000000000..2912ce12b7 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala @@ -0,0 +1,60 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.Secrets.getAccessToken +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.util.Identifiable +import spray.json.DefaultJsonProtocol.IntJsonFormat + +trait TestGlobalParamsTrait extends HasGlobalParams { + val testParam: Param[Double] = new Param[Double]( + this, "TestParam", "Test Param") + + GlobalParams.registerParam(testParam, TestParamKey) + + def getTestParam: Double = getGlobalParam(testParam) + + def setTestParam(v: Double): this.type = set(testParam, v) + + val testServiceParam = new ServiceParam[Int]( + this, "testServiceParam", "Test Service Param", isRequired = false) +} + +class TestGlobalParams extends TestGlobalParamsTrait { + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override val uid: String = Identifiable.randomUID("TestGlobalParams") +} + + +case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} +case object TestServiceParamKey extends GlobalKey[Int] +{val name: String = "TestServiceParam"; val isServiceParam = true} + +class GlobalParamSuite extends Flaky { + + override def beforeAll(): Unit = { + val aadToken = getAccessToken("https://cognitiveservices.azure.com/") + println(s"Triggering token creation early ${aadToken.length}") + super.beforeAll() + } + + val testGlobalParams = new TestGlobalParams() + + test("Basic Usage") { + GlobalParams.setGlobalParam(TestParamKey, 12.5) + assert(testGlobalParams.getTestParam == 12.5) + } + + test("Test Changing Value") { + assert(testGlobalParams.getTestParam == 12.5) + GlobalParams.setGlobalParam("TestParam", 19.853) + assert(testGlobalParams.getTestParam == 19.853) + } +} + From 4bd764bed27d370236d1646645446eafab375273 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 21 Nov 2024 17:19:55 -0500 Subject: [PATCH 6/8] fix comments - Works! --- .../ml/services/CognitiveServiceBase.scala | 20 ++------ .../synapse/ml/services/openai/OpenAI.scala | 3 +- .../ml/services/openai/OpenAIDefaults.scala | 2 +- .../ml/services/openai/OpenAIPrompt.scala | 18 ++------ .../azure/synapse/ml/param/GlobalParams.scala | 46 ++++--------------- .../synapse/ml/param/GlobalParamsSuite.scala | 43 +++++++---------- 6 files changed, 33 insertions(+), 99 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 67b8f69782..850b2915a7 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient import com.microsoft.azure.synapse.ml.io.http._ import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam} +import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam} import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda} import org.apache.http.NameValuePair import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase} @@ -493,7 +493,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform with HasURL with ComplexParamsWritable with HasSubscriptionKey with HasErrorCol with HasAADToken with HasCustomCogServiceDomain - with SynapseMLLogging { + with SynapseMLLogging with HasGlobalParams { setDefault( outputCol -> (this.uid + "_output"), @@ -547,21 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform } override def transform(dataset: Dataset[_]): DataFrame = { - // check for empty params. Fill em with OpenAIDefaults - this.params - .filter(p => !this.isSet(p) && !this.hasDefault(p)) - .foreach { p => - GlobalParams.getParam(p) match { - case Some(v) => - p match { - case serviceParam: ServiceParam[_] => - setScalarParam(serviceParam.asInstanceOf[ServiceParam[Any]], v) - case param: Param[_] => - set(param.asInstanceOf[Param[Any]], v) - } - case None => - } - } + transferGlobalParamsToParamMap() logTransform[DataFrame](getInternalTransformer(dataset.schema).transform(dataset), dataset.columns.length ) } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index f64c6f91e1..2403af3ed9 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -51,8 +51,7 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } -case object OpenAIDeploymentNameKey extends GlobalKey[String] -{val name: String = "OpenAIDeploymentName"} +case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]] trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index a50e7fa187..122254f001 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -4,6 +4,6 @@ import com.microsoft.azure.synapse.ml.param.GlobalParams object OpenAIDefaults { def setDeploymentName(v: String): Unit = { - GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, v) + GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v)) } } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 150f710966..136e5b28d1 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam, StringStringMapParam} +import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam, StringStringMapParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} @@ -28,7 +28,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader with HasCognitiveServiceInput - with ComplexParamsWritable with SynapseMLLogging { + with ComplexParamsWritable with SynapseMLLogging with HasGlobalParams { logClass(FeatureNames.AiServices.OpenAI) @@ -124,19 +124,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer override def transform(dataset: Dataset[_]): DataFrame = { import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._ - this.params - .filter(p => !this.isSet(p) && !this.hasDefault(p)) - .foreach { p => - GlobalParams.getParam(p) match { - case Some(v) => - p match { - case serviceParam: ServiceParam[_] => setScalarParam(serviceParam.asInstanceOf[ServiceParam[Any]], v) - case param: Param[_] => - set(param.asInstanceOf[Param[Any]], v) - } - case None => - } - } + transferGlobalParamsToParamMap() logTransform[DataFrame]({ val df = dataset.toDF diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 42b3173700..a079d2422a 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -6,39 +6,17 @@ import org.apache.spark.sql.Row import scala.collection.mutable import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils -trait GlobalKey[T] { - val name: String -} +trait GlobalKey[T] object GlobalParams { private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty - private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = { - val strToKeyMap = mutable.Map[String, GlobalKey[_]]() - JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] => - strToKeyMap += (key.name -> key) - } - strToKeyMap - } - - private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = { - StringtoKeyMap.get(keyName) - } def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { GlobalParams(key) = value } - def setGlobalParam[T](keyName: String, value: T): Unit = { - val key = findGlobalKeyByName(keyName) - key match { - case Some(k) => - GlobalParams(k) = value - case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams") - } - } - private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) } @@ -56,24 +34,18 @@ object GlobalParams { def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { ParamToKeyMap(p) = key } - - def registerParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = { - ParamToKeyMap(sp) = key - } } trait HasGlobalParams extends Params{ - def getGlobalParam[T](p: Param[T]): T = { - try { - this.getOrDefault(p) - } - catch { - case e: Exception => - GlobalParams.getParam(p) match { - case Some(v) => v - case None => throw e + private[ml] def transferGlobalParamsToParamMap(): Unit = { + // check for empty params. Fill em with GlobalParams + this.params + .filter(p => !this.isSet(p) && !this.hasDefault(p)) + .foreach { p => + GlobalParams.getParam(p).foreach { v => + set(p.asInstanceOf[Param[Any]], v) } - } + } } } \ No newline at end of file diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala index 2912ce12b7..268f4b48f0 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala @@ -3,58 +3,47 @@ package com.microsoft.azure.synapse.ml.param -import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.core.test.base.Flaky import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.Identifiable -import spray.json.DefaultJsonProtocol.IntJsonFormat -trait TestGlobalParamsTrait extends HasGlobalParams { +case object TestParamKey extends GlobalKey[Double] + +class TestGlobalParams extends HasGlobalParams { + override val uid: String = Identifiable.randomUID("TestGlobalParams") + val testParam: Param[Double] = new Param[Double]( - this, "TestParam", "Test Param") + this, "TestParam", "Test Param for testing") + + println(testParam.parent) + println(hasParam(testParam.name)) GlobalParams.registerParam(testParam, TestParamKey) - def getTestParam: Double = getGlobalParam(testParam) + def getTestParam: Double = $(testParam) def setTestParam(v: Double): this.type = set(testParam, v) - val testServiceParam = new ServiceParam[Int]( - this, "testServiceParam", "Test Service Param", isRequired = false) -} - -class TestGlobalParams extends TestGlobalParamsTrait { - override def copy(extra: ParamMap): Transformer = defaultCopy(extra) - override val uid: String = Identifiable.randomUID("TestGlobalParams") } - -case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false} -case object TestServiceParamKey extends GlobalKey[Int] -{val name: String = "TestServiceParam"; val isServiceParam = true} - class GlobalParamSuite extends Flaky { - override def beforeAll(): Unit = { - val aadToken = getAccessToken("https://cognitiveservices.azure.com/") - println(s"Triggering token creation early ${aadToken.length}") - super.beforeAll() - } - val testGlobalParams = new TestGlobalParams() test("Basic Usage") { GlobalParams.setGlobalParam(TestParamKey, 12.5) + testGlobalParams.transferGlobalParamsToParamMap() assert(testGlobalParams.getTestParam == 12.5) } - test("Test Changing Value") { - assert(testGlobalParams.getTestParam == 12.5) - GlobalParams.setGlobalParam("TestParam", 19.853) - assert(testGlobalParams.getTestParam == 19.853) + test("Test Setting Directly Value") { + testGlobalParams.setTestParam(18.7334) + GlobalParams.setGlobalParam(TestParamKey, 19.853) + testGlobalParams.transferGlobalParamsToParamMap() + assert(testGlobalParams.getTestParam == 18.7334) } } From 6e09fbf9997cd256a706025aa89c489d5f7dea95 Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 21 Nov 2024 21:34:01 -0500 Subject: [PATCH 7/8] Remove unused imports --- .../azure/synapse/ml/services/openai/OpenAIPrompt.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 136e5b28d1..49db4911f2 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam, StringStringMapParam} +import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} From ca44c8ff4dc0c27542416149714ea0be7d7b384b Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Thu, 21 Nov 2024 23:20:40 -0500 Subject: [PATCH 8/8] Add headers and fix style bugs --- .../azure/synapse/ml/services/openai/OpenAIDefaults.scala | 3 +++ .../synapse/ml/services/openai/OpenAIDefaultsSuite.scala | 3 +++ .../microsoft/azure/synapse/ml/param/GlobalParams.scala | 7 ++++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index 122254f001..8b91625064 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -1,3 +1,6 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.param.GlobalParams diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index 7ed1c064c2..0f156d02f2 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -1,3 +1,6 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.core.test.base.Flaky diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index a079d2422a..98f7eb33e6 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -1,10 +1,11 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + package com.microsoft.azure.synapse.ml.param import org.apache.spark.ml.param.{Param, Params} -import org.apache.spark.sql.Row import scala.collection.mutable -import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils trait GlobalKey[T] @@ -48,4 +49,4 @@ trait HasGlobalParams extends Params{ } } } -} \ No newline at end of file +}