diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala
new file mode 100644
--- /dev/null
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/GlobalParam.scala
@@ -0,0 +1,239 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.services
+
+import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils
+import com.microsoft.azure.synapse.ml.param.ServiceParam
+import org.apache.spark.ml.param.Param
+import org.apache.spark.sql.Row
+
+import scala.collection.mutable
+
+/** Identifier under which a process-wide default value is stored in [[GlobalParams]].
+  *
+  * @tparam T type of the value associated with this key
+  */
+trait GlobalKey[T] {
+  /** Name accepted by the string-based setters on [[GlobalParams]]. */
+  val name: String
+  /** True when the key backs a `ServiceParam`, false when it backs a plain `Param`. */
+  val isServiceParam: Boolean
+}
+
+/** Registry of global default values that stages fall back to when a param has no local value.
+  *
+  * Plain params are stored as raw values. Service params are stored as `Either[T, String]`:
+  * `Left` holds a scalar value and `Right` holds a column name.
+  *
+  * NOTE(review): both maps are plain `mutable.Map`s whose entries are never removed and which
+  * are not synchronized here - confirm callers do not mutate them concurrently.
+  */
+object GlobalParams {
+  private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty
+  private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty
+
+  // Name -> key index built once from the objects JarLoadingUtils.instantiateObjects returns.
+  private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = {
+    val strToKeyMap = mutable.Map[String, GlobalKey[_]]()
+    JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] =>
+      strToKeyMap += (key.name -> key)
+    }
+    strToKeyMap
+  }
+
+  private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = {
+    StringtoKeyMap.get(keyName)
+  }
+
+  // Resolves a key by name, failing with a message that actually contains the offending name.
+  private def requireKey(keyName: String): GlobalKey[_] = {
+    findGlobalKeyByName(keyName).getOrElse(
+      throw new IllegalArgumentException(s"$keyName is not a valid key in GlobalParams"))
+  }
+
+  /** Sets the global value for a plain-param key.
+    *
+    * Fails an assertion if `key` is flagged as a service param.
+    */
+  def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = {
+    assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParam should be used.")
+    GlobalParams(key) = value
+  }
+
+  /** Name-based variant of `setGlobalParam`.
+    *
+    * @throws IllegalArgumentException if `keyName` does not match any known key
+    */
+  def setGlobalParam[T](keyName: String, value: T): Unit = {
+    val key = requireKey(keyName)
+    assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParam should be used.")
+    GlobalParams(key) = value
+  }
+
+  /** Sets a scalar global value (stored as `Left`) for a service-param key.
+    *
+    * Fails an assertion if `key` is not flagged as a service param.
+    */
+  def setGlobalServiceParam[T](key: GlobalKey[T], value: T): Unit = {
+    assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParam should be used.")
+    GlobalParams(key) = Left(value)
+  }
+
+  /** Name-based variant of `setGlobalServiceParam`.
+    *
+    * @throws IllegalArgumentException if `keyName` does not match any known key
+    */
+  def setGlobalServiceParam[T](keyName: String, value: T): Unit = {
+    val key = requireKey(keyName)
+    assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParam should be used.")
+    GlobalParams(key) = Left(value)
+  }
+
+  /** Sets a column name (stored as `Right`) as the global value for a service-param key.
+    *
+    * Fails an assertion if `key` is not flagged as a service param.
+    */
+  def setGlobalServiceParamCol[T](key: GlobalKey[T], value: String): Unit = {
+    assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParam should be used.")
+    GlobalParams(key) = Right(value)
+  }
+
+  /** Name-based variant of `setGlobalServiceParamCol`.
+    *
+    * @throws IllegalArgumentException if `keyName` does not match any known key
+    */
+  def setGlobalServiceParamCol[T](keyName: String, value: String): Unit = {
+    val key = requireKey(keyName)
+    assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParam should be used.")
+    GlobalParams(key) = Right(value)
+  }
+
+  private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = {
+    GlobalParams.get(key).map(_.asInstanceOf[T])
+  }
+
+  private def getGlobalServiceParam[T](key: GlobalKey[T]): Option[Either[T, String]] = {
+    GlobalParams.get(key).map { stored =>
+      // Check before casting: a cast placed first would fail on its own and bypass this assertion.
+      assert(stored.isInstanceOf[Either[_, _]], "getGlobalServiceParam used to fetch a normal Param")
+      stored.asInstanceOf[Either[T, String]]
+    }
+  }
+
+  /** Returns the global value registered for the plain param `p`, if any.
+    *
+    * Fails an assertion if `p` was registered against a service-param key.
+    */
+  def getParam[T](p: Param[T]): Option[T] = {
+    ParamToKeyMap.get(p).flatMap { key =>
+      assert(!key.isServiceParam, s"${key.name} is a Service Param. getServiceParam should be used.")
+      // registerParam pairs Param[T] with GlobalKey[T], so narrowing the stored key is sound.
+      getGlobalParam(key.asInstanceOf[GlobalKey[T]])
+    }
+  }
+
+  /** Returns the global scalar-or-column setting registered for the service param `sp`, if any.
+    *
+    * Fails an assertion if `sp` was registered against a plain-param key.
+    */
+  def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = {
+    ParamToKeyMap.get(sp).flatMap { key =>
+      assert(key.isServiceParam, s"${key.name} is a Param. getParam should be used.")
+      getGlobalServiceParam(key.asInstanceOf[GlobalKey[T]])
+    }
+  }
+
+  /** Returns the global scalar value for `sp`; `None` when unset or when a column is set instead. */
+  def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] =
+    getServiceParam(sp).collect { case Left(value) => value }
+
+  /** Returns the global column name for `sp`; `None` when unset or when a scalar is set instead. */
+  def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] =
+    getServiceParam(sp).collect { case Right(colName) => colName }
+
+  /** Associates the plain param `p` with `key` so global lookups for `p` resolve to that key. */
+  def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = {
+    assert(!key.isServiceParam, s"${key.name} is a Service Param. registerServiceParam should be used.")
+    ParamToKeyMap(p) = key
+  }
+
+  /** Associates the service param `sp` with `key` so global lookups for `sp` resolve to that key. */
+  def registerServiceParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = {
+    assert(key.isServiceParam, s"${key.name} is a Param. registerParam should be used.")
+    ParamToKeyMap(sp) = key
+  }
+}
+
+/** Mixin whose getters fall back to [[GlobalParams]] when the stage itself has no usable value.
+  *
+  * NOTE(review): [[GlobalParams]] is an ordinary object holding in-memory maps - confirm the
+  * values are present in every JVM where these getters run.
+  */
+trait HasGlobalParams extends HasServiceParams {
+
+  // Evaluates `local`; if it throws, returns the global fallback when present, else rethrows.
+  private def orGlobal[T](local: => T)(global: => Option[T]): T = {
+    try {
+      local
+    } catch {
+      case e: Exception => global.getOrElse(throw e)
+    }
+  }
+
+  /** Value of `p` on this stage (set or default), else its global value.
+    *
+    * Rethrows the local lookup failure when no global value exists either.
+    */
+  def getGlobalParam[T](p: Param[T]): T =
+    orGlobal(this.getOrDefault(p))(GlobalParams.getParam(p))
+
+  /** Name-based variant of `getGlobalParam`. */
+  def getGlobalParam[T](name: String): T =
+    getGlobalParam[T](this.getParam(name).asInstanceOf[Param[T]])
+
+  /** Scalar value of `p` on this stage, else its global scalar value. */
+  def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T =
+    orGlobal(this.getScalarParam(p))(GlobalParams.getServiceParamScalar(p))
+
+  /** Column name of `p` on this stage, else its global column name. */
+  def getGlobalServiceParamVector[T](p: ServiceParam[T]): String =
+    orGlobal(this.getVectorParam(p))(GlobalParams.getServiceParamVector(p))
+
+  /** Name-based variant of `getGlobalServiceParamScalar`. */
+  def getGlobalServiceParamScalar[T](name: String): T =
+    getGlobalServiceParamScalar[T](this.getParam(name).asInstanceOf[ServiceParam[T]])
+
+  /** Name-based variant of `getGlobalServiceParamVector`. */
+  def getGlobalServiceParamVector[T](name: String): String =
+    getGlobalServiceParamVector[T](this.getParam(name).asInstanceOf[ServiceParam[T]])
+
+  /** Resolves `p` for `row`: the stage's own setting (set or default) first, then the global one.
+    *
+    * A scalar setting yields its value; a column setting yields that column's value from `row`,
+    * with `null` mapped to `None`.
+    */
+  protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = {
+    def resolve(setting: Either[T, String]): Option[T] = setting match {
+      case Right(colName) => Option(row.getAs[T](colName))
+      case Left(value) => Some(value)
+    }
+    // A def, not a val: the global column is only read from `row` once the local lookup came up
+    // empty, so a global column absent from `row` cannot break a stage that has its own value.
+    def globalValue: Option[T] = GlobalParams.getServiceParam(p).flatMap(resolve)
+
+    try {
+      get(p).orElse(getDefault(p)).flatMap(resolve).orElse(globalValue)
+    } catch {
+      case e: Exception => globalValue.orElse(throw e)
+    }
+  }
+
+  /** Like `getGlobalServiceParamValueOpt` but fails when neither source yields a value.
+    *
+    * @throws NoSuchElementException if no local or global value is available
+    */
+  protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T =
+    getGlobalServiceParamValueOpt(row, p).getOrElse(
+      throw new NoSuchElementException(s"No local or global value found for param ${p.name}"))
+}
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
index 8d40f49ee6..e89bb2b735 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
@@ -51,12 +51,17 @@ trait HasMessagesInput extends Params {
   def setMessagesCol(v: String): this.type = set(messagesCol, v)
 }
 
-trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {
+case object OpenAIDeploymentNameKey extends GlobalKey[String]
+{val name: String = "OpenAIDeploymentName"; val isServiceParam = true}
+
+trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion {
 
   val deploymentName = new ServiceParam[String](
     this, "deploymentName", "The name of the deployment", isRequired = false)
 
-  def getDeploymentName: String = getScalarParam(deploymentName)
+  GlobalParams.registerServiceParam(deploymentName, OpenAIDeploymentNameKey)
+
+  def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName)
 
   def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)
 
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
index 379d797766..d7a81e1f8b 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
@@ -34,7 +34,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
   }
 
   override protected def prepareUrlRoot: Row => String = { row =>
-    s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions"
+    s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/chat/completions"
   }
 
   override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala
index 219bc34d87..73d97cf7b0 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala
@@ -34,7 +34,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
   }
 
   override protected def prepareUrlRoot: Row => String = { row =>
-    s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions"
+    s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/completions"
   }
 
   override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala
index 342254f8fb..cdb6e6d070 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala
@@ -60,7 +60,7 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
   }
 
   override protected def prepareUrlRoot: Row => String = { row =>
-    s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings"
+    s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/embeddings"
   }
 
   private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = {
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala
new file mode 100644
--- /dev/null
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/GlobalParamSuite.scala
@@ -0,0 +1,125 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.services
+
+import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
+import com.microsoft.azure.synapse.ml.core.test.base.Flaky
+import com.microsoft.azure.synapse.ml.param.ServiceParam
+import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIPrompt}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import spray.json.DefaultJsonProtocol.IntJsonFormat
+
+trait TestGlobalParamsTrait extends HasGlobalParams {
+  val testParam: Param[Double] = new Param[Double](
+    this, "TestParam", "Test Param")
+
+  GlobalParams.registerParam(testParam, TestParamKey)
+
+  def getTestParam: Double = getGlobalParam(testParam)
+
+  def setTestParam(v: Double): this.type = set(testParam, v)
+
+  val testServiceParam = new ServiceParam[Int](
+    this, "testServiceParam", "Test Service Param", isRequired = false)
+
+  GlobalParams.registerServiceParam(testServiceParam, TestServiceParamKey)
+
+  def getTestServiceParam: Int = getGlobalServiceParamScalar(testServiceParam)
+
+  def setTestServiceParam(v: Int): this.type = setScalarParam(testServiceParam, v)
+
+  def getTestServiceParamCol: String = getVectorParam(testServiceParam)
+
+  def setTestServiceParamCol(v: String): this.type = setVectorParam(testServiceParam, v)
+}
+
+class TestGlobalParams extends TestGlobalParamsTrait {
+
+  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
+
+  override val uid: String = Identifiable.randomUID("TestGlobalParams")
+}
+
+
+case object TestParamKey extends GlobalKey[Double] {val name: String = "TestParam"; val isServiceParam = false}
+case object TestServiceParamKey extends GlobalKey[Int]
+{val name: String = "TestServiceParam"; val isServiceParam = true}
+
+class GlobalParamSuite extends Flaky with OpenAIAPIKey {
+
+  import spark.implicits._
+
+  override def beforeAll(): Unit = {
+    val aadToken = getAccessToken("https://cognitiveservices.azure.com/")
+    println(s"Triggering token creation early ${aadToken.length}")
+    super.beforeAll()
+  }
+
+  val testGlobalParams = new TestGlobalParams()
+
+  test("Basic Usage") {
+    GlobalParams.setGlobalParam(TestParamKey, 12.5)
+    GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1)
+    assert(testGlobalParams.getTestParam == 12.5)
+    assert(testGlobalParams.getTestServiceParam == 1)
+  }
+
+  test("Test Changing Value") {
+    // Seed the globals here so this test does not depend on "Basic Usage" having run first.
+    GlobalParams.setGlobalParam(TestParamKey, 12.5)
+    GlobalParams.setGlobalServiceParam(TestServiceParamKey, 1)
+    assert(testGlobalParams.getTestParam == 12.5)
+    GlobalParams.setGlobalParam("TestParam", 19.853)
+    assert(testGlobalParams.getTestParam == 19.853)
+    assert(testGlobalParams.getTestServiceParam == 1)
+    GlobalParams.setGlobalServiceParam("TestServiceParam", 4)
+    assert(testGlobalParams.getTestServiceParam == 4)
+  }
+
+  test("Using wrong setters and getters") {
+    assertThrows[AssertionError] {
+      GlobalParams.setGlobalServiceParam("TestParam", 8.8888)
+    }
+    assertThrows[AssertionError] {
+      GlobalParams.getParam(testGlobalParams.testServiceParam)
+    }
+  }
+
+  test("Unknown key name is reported in the error") {
+    val error = intercept[IllegalArgumentException] {
+      GlobalParams.setGlobalParam("NoSuchGlobalKey", 1.0)
+    }
+    assert(error.getMessage.contains("NoSuchGlobalKey"))
+  }
+
+  GlobalParams.setGlobalServiceParam("OpenAIDeploymentName", deploymentName)
+
+  lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
+    .setSubscriptionKey(openAIAPIKey)
+    .setCustomServiceName(openAIServiceName)
+    .setOutputCol("outParsed")
+    .setTemperature(0)
+
+  lazy val df: DataFrame = Seq(
+    ("apple", "fruits"),
+    ("mercedes", "cars"),
+    ("cake", "dishes"),
+    (null, "none") //scalastyle:ignore null
+  ).toDF("text", "category")
+
+  test("OpenAIPrompt w Globals") {
+    val nonNullCount = prompt
+      .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
+      .setPostProcessing("csv")
+      .transform(df)
+      .select("outParsed")
+      .collect()
+      .count(r => Option(r.getSeq[String](0)).isDefined)
+
+    assert(nonNullCount == 3)
+  }
+}