Skip to content

Commit

Permalink
Adding SubscriptionKey and Temperature
Browse files Browse the repository at this point in the history
  • Loading branch information
sss04 committed Dec 2, 2024
1 parent 79d5b58 commit d88fab3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services.openai.OpenAIDeploymentNameKey
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
Expand Down Expand Up @@ -128,10 +129,14 @@ trait HasServiceParams extends Params {
}
}

case object OpenAISubscriptionKey extends GlobalKey[Either[String, String]]

trait HasSubscriptionKey extends HasServiceParams {
val subscriptionKey = new ServiceParam[String](
this, "subscriptionKey", "the API key to use")

GlobalParams.registerParam(subscriptionKey, OpenAISubscriptionKey)

def getSubscriptionKey: String = getScalarParam(subscriptionKey)

def setSubscriptionKey(v: String): this.type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}
}

case object OpenAITemperatureKey extends GlobalKey[Either[Double, String]]

trait HasOpenAITextParams extends HasOpenAISharedParams {
val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
Expand All @@ -126,6 +128,8 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" We generally recommend using this or `top_p` but not both. Minimum of 0 and maximum of 2 allowed.",
isRequired = false)

GlobalParams.registerParam(temperature, OpenAITemperatureKey)

def getTemperature: Double = getScalarParam(temperature)

def setTemperature(v: Double): this.type = setScalarParam(temperature, v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams
import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey

object OpenAIDefaults {
def setDeploymentName(v: String): Unit = {
GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v))
}

def setSubscriptionKey(v: String): Unit = {
GlobalParams.setGlobalParam(OpenAISubscriptionKey, Left(v))
}

def setTemperature(v: Double): Unit = {
GlobalParams.setGlobalParam(OpenAITemperatureKey, Left(v))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
import spark.implicits._

OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)


def promptCompletion: OpenAICompletion = new OpenAICompletion()
.setCustomServiceName(openAIServiceName)
.setMaxTokens(200)
.setOutputCol("out")
.setSubscriptionKey(openAIAPIKey)
.setPromptCol("prompt")

lazy val promptDF: DataFrame = Seq(
Expand All @@ -33,10 +35,8 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

lazy val df: DataFrame = Seq(
("apple", "fruits"),
Expand All @@ -56,4 +56,10 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {

assert(nonNullCount == 3)
}

test("OpenAIPrompt Check Params") {
assert(prompt.getDeploymentName == deploymentName)
assert(prompt.getSubscriptionKey == openAIAPIKey)
assert(prompt.getTemperature == 0.05)
}
}

0 comments on commit d88fab3

Please sign in to comment.