Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Global Params #2318

Merged
merged 13 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
Expand Down Expand Up @@ -547,6 +547,21 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
}

/** Transforms `dataset` after back-filling unset params from `GlobalParams`.
  *
  * Every param of this stage that is neither explicitly set nor has a default is looked up in
  * `GlobalParams`. When a global value exists it is written onto this stage (service params via
  * `setScalarParam`, all other params via `set`) before the internal transformer is invoked.
  *
  * @param dataset input data
  * @return the result of the internal transformer, passed through `logTransform`
  */
override def transform(dataset: Dataset[_]): DataFrame = {
  // Snapshot the params that have no local value before any of them is mutated below.
  val unfilled = this.params.filter(candidate => !(this.isSet(candidate) || this.hasDefault(candidate)))
  unfilled.foreach { candidate =>
    GlobalParams.getParam(candidate).foreach { globalValue =>
      candidate match {
        // ServiceParam is tested first so that it is not handled by the generic Param case.
        case asService: ServiceParam[_] =>
          setScalarParam(asService.asInstanceOf[ServiceParam[Any]], globalValue)
        case asPlain: Param[_] =>
          set(asPlain.asInstanceOf[Param[Any]], globalValue)
      }
    }
  }
  logTransform[DataFrame](
    getInternalTransformer(dataset.schema).transform(dataset),
    dataset.columns.length
  )
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services.openai
import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
Expand Down Expand Up @@ -51,11 +51,16 @@ trait HasMessagesInput extends Params {
def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

/** Key under which the shared OpenAI deployment name is stored in `GlobalParams`. */
case object OpenAIDeploymentNameKey extends GlobalKey[String] {
  override val name: String = "OpenAIDeploymentName"
}
sss04 marked this conversation as resolved.
Show resolved Hide resolved

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = false)

GlobalParams.registerParam(deploymentName, OpenAIDeploymentNameKey)

def getDeploymentName: String = getScalarParam(deploymentName)

def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams

/** Setters for OpenAI-related values held in `GlobalParams`. */
object OpenAIDefaults {

  /** Stores `v` in `GlobalParams` under `OpenAIDeploymentNameKey`.
    *
    * @param v deployment name to store
    */
  def setDeploymentName(v: String): Unit =
    GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, v)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.spark.Functions
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam, StringStringMapParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
Expand Down Expand Up @@ -124,7 +124,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

this.params
.filter(p => !this.isSet(p) && !this.hasDefault(p))
.foreach { p =>
GlobalParams.getParam(p) match {
case Some(v) =>
p match {
case serviceParam: ServiceParam[_] => setScalarParam(serviceParam.asInstanceOf[ServiceParam[Any]], v)
case param: Param[_] =>
set(param.asInstanceOf[Param[Any]], v)
}
case None =>
}
}
sss04 marked this conversation as resolved.
Show resolved Hide resolved
logTransform[DataFrame]({
val df = dataset.toDF

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.sql.{DataFrame, Row}

/** Runs `OpenAICompletion` and `OpenAIPrompt` with the deployment name supplied only through
  * `OpenAIDefaults`; neither stage built here calls `setDeploymentName` itself.
  */
class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {

  import spark.implicits._

  // Executed once when the suite is constructed, before any test body runs.
  OpenAIDefaults.setDeploymentName(deploymentName)

  /** Builds a fresh completion stage on every call. */
  def promptCompletion: OpenAICompletion = {
    val completion = new OpenAICompletion()
    completion
      .setCustomServiceName(openAIServiceName)
      .setMaxTokens(200)
      .setOutputCol("out")
      .setSubscriptionKey(openAIAPIKey)
      .setPromptCol("prompt")
  }

  lazy val promptDF: DataFrame = {
    val prompts = Seq(
      "Once upon a time",
      "Best programming language award goes to",
      "SynapseML is "
    )
    prompts.toDF("prompt")
  }

  test("Completion w Globals") {
    val fromRow = CompletionResponse.makeFromRowConverter
    for {
      row <- promptCompletion.transform(promptDF).collect()
      choice <- fromRow(row.getAs[Row]("out")).choices
    } assert(choice.text.length > 10)
  }

  lazy val prompt: OpenAIPrompt = {
    val stage = new OpenAIPrompt()
    stage
      .setSubscriptionKey(openAIAPIKey)
      .setCustomServiceName(openAIServiceName)
      .setOutputCol("outParsed")
      .setTemperature(0)
  }

  lazy val df: DataFrame = {
    val rows = Seq(
      ("apple", "fruits"),
      ("mercedes", "cars"),
      ("cake", "dishes"),
      (null, "none") //scalastyle:ignore null
    )
    rows.toDF("text", "category")
  }

  test("OpenAIPrompt w Globals") {
    val parsedRows = prompt
      .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
      .setPostProcessing("csv")
      .transform(df)
      .select("outParsed")
      .collect()

    // The row whose "text" is null is expected to yield a null output column.
    val nonNullCount = parsedRows.count(row => Option(row.getSeq[String](0)).nonEmpty)

    assert(nonNullCount == 3)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.microsoft.azure.synapse.ml.param

import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row

import scala.collection.mutable
import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils

/** Typed identifier for a value stored in `GlobalParams`.
  *
  * `name` is the string under which `GlobalParams` indexes the key for by-name lookup.
  *
  * @tparam T type of the value associated with this key
  */
trait GlobalKey[T] { val name: String }

/** Singleton registry of global values that params can fall back to.
  *
  * Two associations are kept: param -> [[GlobalKey]] (filled by `registerParam`) and
  * [[GlobalKey]] -> value (filled by `setGlobalParam`). `getParam` joins the two.
  *
  * The backing maps are plain `scala.collection.mutable.Map`s with no synchronisation.
  * NOTE(review): `registerParam` is called from trait bodies, i.e. once per stage instance; if
  * param equality includes the owning stage's uid, `ParamToKeyMap` grows with every instance —
  * confirm whether entries need to be evicted.
  */
object GlobalParams {
  // Param (plain or service) -> key whose global value it falls back to.
  private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty
  // Key -> current global value. Deliberately not named like the enclosing object.
  private val GlobalValues: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty

  // Name -> key index, built once from the keys returned by JarLoadingUtils.instantiateObjects.
  private val StringtoKeyMap: mutable.Map[String, GlobalKey[_]] = {
    val strToKeyMap = mutable.Map[String, GlobalKey[_]]()
    JarLoadingUtils.instantiateObjects[GlobalKey[_]]().foreach { key: GlobalKey[_] =>
      strToKeyMap += (key.name -> key)
    }
    strToKeyMap
  }

  /** Looks up a key by its `name`. */
  private def findGlobalKeyByName(keyName: String): Option[GlobalKey[_]] =
    StringtoKeyMap.get(keyName)

  /** Stores `value` as the global value for `key`, replacing any previous value. */
  def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = {
    GlobalValues(key) = value
  }

  /** Stores `value` as the global value for the key called `keyName`.
    *
    * The value's type is not checked against the key's type parameter.
    *
    * @throws IllegalArgumentException if no key with that name is known
    */
  def setGlobalParam[T](keyName: String, value: T): Unit = {
    val key = findGlobalKeyByName(keyName).getOrElse(
      // `s` interpolator is required here; without it the literal text "${keyName}" is reported.
      throw new IllegalArgumentException(s"$keyName is not a valid key in GlobalParams"))
    GlobalValues(key) = value
  }

  /** Returns the global value stored for `key`, if any. The cast to `T` is unchecked. */
  private def getGlobalParam[T](key: GlobalKey[T]): Option[T] =
    GlobalValues.get(key).map(_.asInstanceOf[T])

  /** Returns the global value for the key that `p` was registered with.
    *
    * @return `None` when `p` was never registered or its key has no value
    */
  def getParam[T](p: Param[T]): Option[T] =
    // A type test against GlobalKey[T] is erased and always succeeds, so cast explicitly instead.
    ParamToKeyMap.get(p).flatMap(key => getGlobalParam(key.asInstanceOf[GlobalKey[T]]))

  /** Associates the plain param `p` with `key`. */
  def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = {
    ParamToKeyMap(p) = key
  }

  /** Associates the service param `sp` with `key`. */
  def registerParam[T](sp: ServiceParam[T], key: GlobalKey[T]): Unit = {
    ParamToKeyMap(sp) = key
  }
}

/** Mixin that adds a getter falling back to `GlobalParams`. */
trait HasGlobalParams extends Params {

  /** Returns the value of `p` from `getOrDefault`; if that call throws an `Exception`, returns the
    * value `GlobalParams.getParam` yields for `p`, or rethrows the original exception when there is none.
    */
  def getGlobalParam[T](p: Param[T]): T =
    try getOrDefault(p)
    catch {
      case lookupFailure: Exception =>
        GlobalParams.getParam(p).getOrElse(throw lookupFailure)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.param

import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import spray.json.DefaultJsonProtocol.IntJsonFormat

/** Test fixture declaring a plain param registered with `TestParamKey` and a service param
  * that is declared but not registered with any key here.
  */
trait TestGlobalParamsTrait extends HasGlobalParams {

  val testParam: Param[Double] = new Param[Double](this, "TestParam", "Test Param")

  // Runs after testParam has been initialised.
  GlobalParams.registerParam(testParam, TestParamKey)

  val testServiceParam = new ServiceParam[Int](
    this, "testServiceParam", "Test Service Param", isRequired = false)

  /** Sets `testParam` on this instance. */
  def setTestParam(v: Double): this.type = set(testParam, v)

  /** Reads `testParam` through `getGlobalParam`. */
  def getTestParam: Double = getGlobalParam(testParam)
}

/** Concrete holder for the params declared in `TestGlobalParamsTrait`.
  *
  * `uid` is a constructor parameter so that it is already assigned when the trait initialiser
  * constructs its params with `this` as parent; as a body `val` it would still be null then.
  * The `(uid: String)` constructor is also the one `defaultCopy` looks up reflectively.
  *
  * @param uid unique identifier of this instance
  */
class TestGlobalParams(override val uid: String) extends TestGlobalParamsTrait {

  /** Creates an instance with a freshly generated uid. */
  def this() = this(Identifiable.randomUID("TestGlobalParams"))

  // The result type is this class: it is not a Transformer, so a Transformer result could never be produced.
  override def copy(extra: ParamMap): TestGlobalParams = defaultCopy(extra)
}


/** Key for a `Double` value, named "TestParam". */
case object TestParamKey extends GlobalKey[Double] {
  val name: String = "TestParam"
  val isServiceParam = false
}

/** Key for an `Int` value, named "TestServiceParam". */
case object TestServiceParamKey extends GlobalKey[Int] {
  val name: String = "TestServiceParam"
  val isServiceParam = true
}
sss04 marked this conversation as resolved.
Show resolved Hide resolved

/** Exercises setting a global value by key and by key name and reading it back through a param. */
class GlobalParamSuite extends Flaky {

  // NOTE(review): nothing in this suite uses the token; confirm whether this warm-up is needed here.
  override def beforeAll(): Unit = {
    val aadToken = getAccessToken("https://cognitiveservices.azure.com/")
    println(s"Triggering token creation early ${aadToken.length}")
    super.beforeAll()
  }

  val testGlobalParams = new TestGlobalParams()

  test("Basic Usage") {
    GlobalParams.setGlobalParam(TestParamKey, 12.5)
    assert(testGlobalParams.getTestParam == 12.5)
  }

  test("Test Changing Value") {
    // Seed the starting value here so this test does not depend on "Basic Usage" having run first
    // and still passes when it is run alone or re-run after the value was already changed.
    GlobalParams.setGlobalParam(TestParamKey, 12.5)
    assert(testGlobalParams.getTestParam == 12.5)
    GlobalParams.setGlobalParam("TestParam", 19.853)
    assert(testGlobalParams.getTestParam == 19.853)
  }
}
}