Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Global Param Object #2306

Closed
wants to merge 7 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
package com.microsoft.azure.synapse.ml.services

import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.spark.ml.param.Param
import org.apache.spark.sql.Row

import scala.collection.mutable
import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils

trait GlobalKey[T] {
val name: String
val isServiceParam: Boolean
}

object GlobalParams {
private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty
private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty

private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = {
val strToKeyMap = mutable.Map[String, GlobalKey[_]]()
JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] =>
strToKeyMap += (key.name -> key)
}
strToKeyMap
}

private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] = {
StringtoKeyMap.get(keyName)
}

def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = {
assert(!key.isServiceParam, s"${key.name} is a Service Param. setGlobalServiceParamKey should be used.")
GlobalParams(key) = value
}

def setGlobalParam[T](keyName: String, value: T): Unit = {
val key = findGlobalKeyByName(keyName)
key match {
case Some(k) =>
assert(!k.isServiceParam, s"${k.name} is a Service Param. setGlobalServiceParamKey should be used.")
GlobalParams(k) = value
case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams")
}
}

def setGlobalServiceParam[T](key: GlobalKey[T], value: T): Unit = {
assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.")
GlobalParams(key) = Left(value)
}

def setGlobalServiceParam[T](keyName: String, value: T): Unit = {
val key = findGlobalKeyByName(keyName)
key match {
case Some(k) =>
assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.")
GlobalParams(k) = Left(value)
case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams")
}
}

def setGlobalServiceParamCol[T](key: GlobalKey[T], value: String): Unit = {
assert(key.isServiceParam, s"${key.name} is a Param. setGlobalParamKey should be used.")
GlobalParams(key) = Right(value)
}

def setGlobalServiceParamCol[T](keyName: String, value: String): Unit = {
val key = findGlobalKeyByName(keyName)
key match {
case Some(k) =>
assert(k.isServiceParam, s"${k.name} is a Param. setGlobalParamKey should be used.")
GlobalParams(k) = Right(value)
case None => throw new IllegalArgumentException("${keyName} is not a valid key in GlobalParams")
}
}

private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = {
GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T])
}

private def getGlobalServiceParam[T](key: GlobalKey[T]): Option[Either[T, String]] = {
val value = GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[Either[T, String]])
value match {
case some @ Some(v) =>
assert(v.isInstanceOf[Either[T, String]],
"getGlobalServiceParam used to fetch a normal Param")
value
case None => None
}
}

def getParam[T](p: Param[T]): Option[T] = {
ParamToKeyMap.get(p).flatMap { key =>
key match {
case k: GlobalKey[T] =>
assert(!k.isServiceParam, s"${k.name} is a Service Param. getServiceParam should be used.")
getGlobalParam(k)
case _ => None
}
}
}

def getServiceParam[T](sp: ServiceParam[T]): Option[Either[T, String]] = {
ParamToKeyMap.get(sp).flatMap { key =>
key match {
case k: GlobalKey[T] =>
assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.")
getGlobalServiceParam(k)
case _ => None
}
}
}

def getServiceParamScalar[T](sp: ServiceParam[T]): Option[T] = {
ParamToKeyMap.get(sp).flatMap { key =>
key match {
case k: GlobalKey[T] =>
getGlobalServiceParam(k) match {
case Some(Left(value)) =>
assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.")
Some(value)
case _ => None
}
case _ => None
}
}
}

def getServiceParamVector[T](sp: ServiceParam[T]): Option[String] = {
ParamToKeyMap.get(sp).flatMap { key =>
key match {
case k: GlobalKey[T] =>
getGlobalServiceParam(k) match {
case Some(Right(colName)) =>
assert(k.isServiceParam, s"${k.name} is a Param. getParam should be used.")
Some(colName)
case _ => None
}
case _ => None
}
}
}

def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = {
assert(!key.isServiceParam, s"${key.name} is a Service Param. registerServiceParam should be used.")
ParamToKeyMap(p) = key
}

def registerServiceParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = {
assert(key.isServiceParam, s"${key.name} is a Param. registerParam should be used.")
ParamToKeyMap(sp) = key
}
}

trait HasGlobalParams extends HasServiceParams {

def getGlobalParam[T](p: Param[T]): T = {
try {
this.getOrDefault(p)
}
catch {
case e: Exception =>
GlobalParams.getParam(p) match {
case Some(v) => v
case None => throw e
}
}
}

def getGlobalParam[T](name: String): T = {
val param = this.getParam(name).asInstanceOf[Param[T]]
try {
this.getOrDefault(param)
}
catch {
case e: Exception =>
GlobalParams.getParam(param) match {
case Some(v) => v
case None => throw e
}
}
}

def getGlobalServiceParamScalar[T](p: ServiceParam[T]): T = {
try {
this.getScalarParam(p)
}
catch {
case e: Exception =>
GlobalParams.getServiceParamScalar(p) match {
case Some(v) => v
case None => throw e
}
}
}

def getGlobalServiceParamVector[T](p: ServiceParam[T]): String= {
try {
this.getVectorParam(p)
}
catch {
case e: Exception =>
GlobalParams.getServiceParamVector(p)match {
case Some(v) => v
case None => throw e
}
}
}

def getGlobalServiceParamScalar[T](name: String): T = {
val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]]
try {
this.getScalarParam(serviceParam)
}
catch {
case e: Exception =>
GlobalParams.getServiceParamScalar(serviceParam) match {
case Some(v) => v
case None => throw e
}
}
}

def getGlobalServiceParamVector[T](name: String): String = {
val serviceParam = this.getParam(name).asInstanceOf[ServiceParam[T]]
try {
this.getVectorParam(serviceParam)
}
catch {
case e: Exception =>
GlobalParams.getServiceParamVector(serviceParam) match {
case Some(v) => v
case None => throw e
}
}
}

protected def getGlobalServiceParamValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = {
val globalParam: Option[T] = GlobalParams.getServiceParam(p).flatMap {
case Right(colName) => Option(row.getAs[T](colName))
case Left(value) => Some(value)
}
try {
get(p).orElse(getDefault(p)).flatMap {
case Right(colName) => Option(row.getAs[T](colName))
case Left(value) => Some(value)
}
match {
case some @ Some(_) => some
case None => globalParam
}
}
catch {
case e: Exception =>
globalParam match {
case Some(v) => Some(v)
case None => throw e
}
}
}


protected def getGlobalServiceParamValue[T](row: Row, p: ServiceParam[T]): T =
getGlobalServiceParamValueOpt(row, p).get
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,17 @@ trait HasMessagesInput extends Params {
def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {
case object OpenAIDeploymentNameKey extends GlobalKey[String]
{val name: String = "OpenAIDeploymentName"; val isServiceParam = true}

trait HasOpenAISharedParams extends HasGlobalParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = false)

def getDeploymentName: String = getScalarParam(deploymentName)
GlobalParams.registerServiceParam(deploymentName, OpenAIDeploymentNameKey)

def getDeploymentName: String = getGlobalServiceParamScalar(deploymentName)

def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions"
s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/chat/completions"
}

override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions"
s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/completions"
}

override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings"
s"${getUrl}openai/deployments/${getGlobalServiceParamValue(row, deploymentName)}/embeddings"
}

private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = {
Expand Down
Loading