Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Global Params #2318

Merged
merged 13 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
Expand Down Expand Up @@ -493,7 +493,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
with HasURL with ComplexParamsWritable
with HasSubscriptionKey with HasErrorCol
with HasAADToken with HasCustomCogServiceDomain
with SynapseMLLogging {
with SynapseMLLogging with HasGlobalParams {

setDefault(
outputCol -> (this.uid + "_output"),
Expand Down Expand Up @@ -547,6 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
}

override def transform(dataset: Dataset[_]): DataFrame = {
transferGlobalParamsToParamMap()
sss04 marked this conversation as resolved.
Show resolved Hide resolved
logTransform[DataFrame](getInternalTransformer(dataset.schema).transform(dataset), dataset.columns.length
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services.openai
import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
Expand Down Expand Up @@ -51,11 +51,15 @@ trait HasMessagesInput extends Params {
def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

// Global-parameter key for the OpenAI deployment name. The value type mirrors the
// ServiceParam representation: presumably Left holds a literal value and Right a
// column name — NOTE(review): confirm Left/Right semantics against ServiceParam.
case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]]

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = false)

GlobalParams.registerParam(deploymentName, OpenAIDeploymentNameKey)

def getDeploymentName: String = getScalarParam(deploymentName)

def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams

/**
  * Process-wide default settings for SynapseML OpenAI transformers.
  *
  * Values set here are stored in the [[GlobalParams]] registry and are picked up
  * by any stage that registered the corresponding param, unless the stage sets
  * the param explicitly on the instance.
  */
object OpenAIDefaults {

  /**
    * Sets the default deployment name used by OpenAI stages that do not set one
    * themselves. Stored as Left(v), i.e. a scalar value rather than a column.
    *
    * @param v the deployment name to use as the global default
    */
  def setDeploymentName(v: String): Unit =
    GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v))
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.spark.Functions
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
Expand All @@ -28,7 +28,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with HasCognitiveServiceInput
with ComplexParamsWritable with SynapseMLLogging {
with ComplexParamsWritable with SynapseMLLogging with HasGlobalParams {

logClass(FeatureNames.AiServices.OpenAI)

Expand Down Expand Up @@ -124,7 +124,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

transferGlobalParamsToParamMap()
logTransform[DataFrame]({
val df = dataset.toDF

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.sql.{DataFrame, Row}

/**
  * End-to-end checks that OpenAI transformers pick up the deployment name from
  * [[OpenAIDefaults]] when it is not set on the individual stage.
  */
class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {

  import spark.implicits._

  // Register the global default once; no transformer below calls setDeploymentName.
  OpenAIDefaults.setDeploymentName(deploymentName)

  // A def on purpose: every access builds a fresh transformer with no carried-over state.
  def promptCompletion: OpenAICompletion = new OpenAICompletion()
    .setCustomServiceName(openAIServiceName)
    .setMaxTokens(200)
    .setOutputCol("out")
    .setSubscriptionKey(openAIAPIKey)
    .setPromptCol("prompt")

  lazy val promptDF: DataFrame = Seq(
    "Once upon a time",
    "Best programming language award goes to",
    "SynapseML is "
  ).toDF("prompt")

  test("Completion w Globals") {
    val toResponse = CompletionResponse.makeFromRowConverter
    for (row <- promptCompletion.transform(promptDF).collect()) {
      val response = toResponse(row.getAs[Row]("out"))
      // Each completion should produce some non-trivial amount of text.
      response.choices.foreach(choice => assert(choice.text.length > 10))
    }
  }

  lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
    .setSubscriptionKey(openAIAPIKey)
    .setCustomServiceName(openAIServiceName)
    .setOutputCol("outParsed")
    .setTemperature(0)

  lazy val df: DataFrame = Seq(
    ("apple", "fruits"),
    ("mercedes", "cars"),
    ("cake", "dishes"),
    (null, "none") //scalastyle:ignore null
  ).toDF("text", "category")

  test("OpenAIPrompt w Globals") {
    val parsedRows = prompt
      .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
      .setPostProcessing("csv")
      .transform(df)
      .select("outParsed")
      .collect()

    // The null-text input row yields a null parse result, leaving three populated rows.
    val populated = parsedRows.count(row => Option(row.getSeq[String](0)).isDefined)
    assert(populated == 3)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.microsoft.azure.synapse.ml.param

import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row

import scala.collection.mutable
import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils

// Marker trait for keys in the global-parameter registry; T is the type of the
// value stored under the key. Implement as a case object per parameter.
trait GlobalKey[T]

/**
  * Process-wide registry of default parameter values.
  *
  * Stages register a `Param` against a [[GlobalKey]]; users set a value for that
  * key; `HasGlobalParams.transferGlobalParamsToParamMap` later copies the value
  * into any stage that did not set the param explicitly.
  *
  * NOTE(review): backed by plain mutable maps with no synchronization — assumes
  * single-threaded (driver-side) access; confirm before use from multiple threads.
  */
object GlobalParams {
  // Maps a stage's Param instance to the global key it was registered under.
  private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty

  // Maps a global key to the currently configured global value.
  // Renamed from `GlobalParams`, which shadowed the enclosing object's own name.
  private val GlobalValues: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty

  /** Sets (or overwrites) the global value associated with `key`. */
  def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = {
    GlobalValues(key) = value
  }

  /** Returns the global value for `key`, if one has been set. */
  private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = {
    // Values are stored untyped; the cast is safe because setGlobalParam only
    // accepts values of the key's declared type parameter.
    GlobalValues.get(key).map(_.asInstanceOf[T])
  }

  /** Looks up the global value registered for param `p`, if both a key and a value exist. */
  def getParam[T](p: Param[T]): Option[T] = {
    // A `case k: GlobalKey[T]` match here would be unchecked due to erasure and
    // would always succeed; an explicit cast states the intent directly.
    // registerParam guarantees the stored key carries p's type parameter.
    ParamToKeyMap.get(p).flatMap(key => getGlobalParam(key.asInstanceOf[GlobalKey[T]]))
  }

  /** Associates `p` with `key` so global values set for the key can flow into the param. */
  def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = {
    ParamToKeyMap(p) = key
  }
}

/**
  * Mixin for stages whose unset params may be filled from [[GlobalParams]].
  */
trait HasGlobalParams extends Params {

  /**
    * Copies registered global values into this stage's param map.
    *
    * Only params that are neither explicitly set nor carrying a default are
    * filled, so per-instance settings and declared defaults always win over
    * globals. Intended to be called at the start of transform()/fit().
    */
  private[ml] def transferGlobalParamsToParamMap(): Unit = {
    this.params
      .filter(p => !this.isSet(p) && !this.hasDefault(p))
      .foreach { p =>
        GlobalParams.getParam(p).foreach { v =>
          // `params` is Array[Param[_]]; the cast recovers a callable type for set.
          set(p.asInstanceOf[Param[Any]], v)
        }
      }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.param

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable

// Key used by the tests below; registered against TestGlobalParams.testParam.
case object TestParamKey extends GlobalKey[Double]

/** Minimal Params implementation used to exercise the global-parameter plumbing. */
class TestGlobalParams extends HasGlobalParams {
  override val uid: String = Identifiable.randomUID("TestGlobalParams")

  val testParam: Param[Double] = new Param[Double](
    this, "TestParam", "Test Param for testing")

  // Removed leftover debug println statements that printed testParam.parent and
  // hasParam(testParam.name) at construction time.
  GlobalParams.registerParam(testParam, TestParamKey)

  def getTestParam: Double = $(testParam)

  def setTestParam(v: Double): this.type = set(testParam, v)

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

}

class GlobalParamSuite extends Flaky {

  // Each test builds its own stage. The original shared a single instance, which
  // made the tests order-dependent: a param set in one test leaked into the other
  // because transferGlobalParamsToParamMap never overwrites an already-set param.

  test("Basic Usage") {
    val stage = new TestGlobalParams()
    GlobalParams.setGlobalParam(TestParamKey, 12.5)
    stage.transferGlobalParamsToParamMap()
    assert(stage.getTestParam == 12.5)
  }

  test("Test Setting Directly Value") {
    val stage = new TestGlobalParams()
    // An explicitly set param must win over a global value set afterwards.
    stage.setTestParam(18.7334)
    GlobalParams.setGlobalParam(TestParamKey, 19.853)
    stage.transferGlobalParamsToParamMap()
    assert(stage.getTestParam == 18.7334)
  }
}
}