Skip to content

Commit

Permalink
Add URL to OpenAIDefaults and add new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sss04 committed Dec 20, 2024
1 parent c239e92 commit e703bd2
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ def get_temperature(self):

def reset_temperature(self):
self.defaults.resetTemperature()

def set_URL(self, URL):
self.defaults.setURL(URL)

def get_URL(self):
return getOption(self.defaults.getURL())

def reset_URL(self):
self.defaults.resetURL()
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams
import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey
import com.microsoft.azure.synapse.ml.io.http.URLKey

object OpenAIDefaults {
def setDeploymentName(v: String): Unit = {
Expand Down Expand Up @@ -43,6 +44,18 @@ object OpenAIDefaults {
GlobalParams.resetGlobalParam(OpenAITemperatureKey)
}

def setURL(v: String): Unit = {
GlobalParams.setGlobalParam(URLKey, v)
}

def getURL: Option[String] = {
GlobalParams.getGlobalParam(URLKey)
}

def resetURL(): Unit = {
GlobalParams.resetGlobalParam(URLKey)
}

private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = {
optEither match {
case Some(Left(v)) => Some(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,35 @@ def test_setters_and_getters(self):
defaults.set_deployment_name("Bing Bong")
defaults.set_subscription_key("SubKey")
defaults.set_temperature(0.05)
defaults.set_URL("Test URL")

self.assertEqual(defaults.get_deployment_name(), "Bing Bong")
self.assertEqual(defaults.get_subscription_key(), "SubKey")
self.assertEqual(defaults.get_temperature(), 0.05)
self.assertEqual(defaults.get_URL(), "Test URL")

def test_resetters(self):
defaults = OpenAIDefaults()

defaults.set_deployment_name("Bing Bong")
defaults.set_subscription_key("SubKey")
defaults.set_temperature(0.05)
defaults.set_URL("Test URL")

self.assertEqual(defaults.get_deployment_name(), "Bing Bong")
self.assertEqual(defaults.get_subscription_key(), "SubKey")
self.assertEqual(defaults.get_temperature(), 0.05)
self.assertEqual(defaults.get_URL(), "Test URL")

defaults.reset_deployment_name()
defaults.reset_subscription_key()
defaults.reset_temperature()
defaults.reset_URL()

self.assertEqual(defaults.get_deployment_name(), None)
self.assertEqual(defaults.get_subscription_key(), None)
self.assertEqual(defaults.get_temperature(), None)
self.assertEqual(defaults.get_URL(), None)

def test_two_defaults(self):
defaults = OpenAIDefaults()
Expand Down Expand Up @@ -78,10 +84,10 @@ def test_prompt_w_defaults(self):
defaults.set_deployment_name("gpt-35-turbo-0125")
defaults.set_subscription_key(openai_api_key)
defaults.set_temperature(0.05)
defaults.set_URL("https://synapseml-openai-2.openai.azure.com/")

prompt = OpenAIPrompt()
prompt = prompt.setOutputCol("outParsed")
prompt = prompt.setCustomServiceName("synapseml-openai-2")
prompt = prompt.setPromptTemplate("Complete this comma separated list of 5 {category}: {text}, ")
results = prompt.transform(df)
results.select("outParsed").show(truncate = False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
import spark.implicits._

def promptCompletion: OpenAICompletion = new OpenAICompletion()
.setCustomServiceName(openAIServiceName)
.setMaxTokens(200)
.setOutputCol("out")
.setPromptCol("prompt")
Expand All @@ -26,6 +25,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)
OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/")

val fromRow = CompletionResponse.makeFromRowConverter
promptCompletion.transform(promptDF).collect().foreach(r =>
Expand All @@ -34,7 +34,6 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")

lazy val df: DataFrame = Seq(
Expand All @@ -48,6 +47,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)
OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/")

val nonNullCount = prompt
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
Expand All @@ -65,22 +65,31 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
}

test("Test Getters") {
OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)
OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/")

assert(OpenAIDefaults.getDeploymentName.contains(deploymentName))
assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey))
assert(OpenAIDefaults.getTemperature.contains(0.05))
assert(OpenAIDefaults.getURL.contains(s"https://$openAIServiceName.openai.azure.com/"))
}

test("Test Resetters") {
OpenAIDefaults.setDeploymentName(deploymentName)
OpenAIDefaults.setSubscriptionKey(openAIAPIKey)
OpenAIDefaults.setTemperature(0.05)
OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/")

OpenAIDefaults.resetDeploymentName()
OpenAIDefaults.resetSubscriptionKey()
OpenAIDefaults.resetTemperature()
OpenAIDefaults.resetURL()

assert(OpenAIDefaults.getDeploymentName.isEmpty)
assert(OpenAIDefaults.getSubscriptionKey.isEmpty)
assert(OpenAIDefaults.getTemperature.isEmpty)
assert(OpenAIDefaults.getURL.isEmpty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol}
import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.HandlerFunc
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.UDFParam
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, UDFParam}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.param._
Expand Down Expand Up @@ -76,10 +76,14 @@ trait ConcurrencyParams extends Wrappable {
setDefault(concurrency -> 1, timeout -> 60.0)
}

case object URLKey extends GlobalKey[String]

trait HasURL extends Params {

val url: Param[String] = new Param[String](this, "url", "Url of the service")

GlobalParams.registerParam(url, URLKey)

/** @group getParam */
def getUrl: String = $(url)

Expand Down

0 comments on commit e703bd2

Please sign in to comment.