diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index 3a2566fd5f..7aad53f842 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -49,3 +49,12 @@ def get_temperature(self): def reset_temperature(self): self.defaults.resetTemperature() + + def set_URL(self, URL): + self.defaults.setURL(URL) + + def get_URL(self): + return getOption(self.defaults.getURL()) + + def reset_URL(self): + self.defaults.resetURL() diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index f0c102c062..f8405fbe1b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.param.GlobalParams import com.microsoft.azure.synapse.ml.services.OpenAISubscriptionKey +import com.microsoft.azure.synapse.ml.io.http.URLKey object OpenAIDefaults { def setDeploymentName(v: String): Unit = { @@ -43,6 +44,18 @@ object OpenAIDefaults { GlobalParams.resetGlobalParam(OpenAITemperatureKey) } + def setURL(v: String): Unit = { + GlobalParams.setGlobalParam(URLKey, v) + } + + def getURL: Option[String] = { + GlobalParams.getGlobalParam(URLKey) + } + + def resetURL(): Unit = { + GlobalParams.resetGlobalParam(URLKey) + } + private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = { optEither match { case Some(Left(v)) => Some(v) diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py index e276378dbb..a24190c3dc 100644 --- a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py @@ -21,10 +21,12 @@ def test_setters_and_getters(self): defaults.set_deployment_name("Bing Bong") defaults.set_subscription_key("SubKey") defaults.set_temperature(0.05) + defaults.set_URL("Test URL") self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") self.assertEqual(defaults.get_temperature(), 0.05) + self.assertEqual(defaults.get_URL(), "Test URL") def test_resetters(self): defaults = OpenAIDefaults() @@ -32,18 +34,22 @@ def test_resetters(self): defaults.set_deployment_name("Bing Bong") defaults.set_subscription_key("SubKey") defaults.set_temperature(0.05) + defaults.set_URL("Test URL") self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") self.assertEqual(defaults.get_temperature(), 0.05) + self.assertEqual(defaults.get_URL(), "Test URL") defaults.reset_deployment_name() defaults.reset_subscription_key() defaults.reset_temperature() + defaults.reset_URL() self.assertEqual(defaults.get_deployment_name(), None) self.assertEqual(defaults.get_subscription_key(), None) self.assertEqual(defaults.get_temperature(), None) + self.assertEqual(defaults.get_URL(), None) def test_two_defaults(self): defaults = OpenAIDefaults() @@ -78,10 +84,10 @@ def test_prompt_w_defaults(self): defaults.set_deployment_name("gpt-35-turbo-0125") defaults.set_subscription_key(openai_api_key) defaults.set_temperature(0.05) + defaults.set_URL("https://synapseml-openai-2.openai.azure.com/") prompt = OpenAIPrompt() prompt = prompt.setOutputCol("outParsed") - prompt = prompt.setCustomServiceName("synapseml-openai-2") prompt = prompt.setPromptTemplate("Complete this comma separated list of 5 {category}: {text}, ") results = prompt.transform(df) results.select("outParsed").show(truncate = False) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index bb108bcb6d..487e0345bc 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -11,7 +11,6 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { import spark.implicits._ def promptCompletion: OpenAICompletion = new OpenAICompletion() - .setCustomServiceName(openAIServiceName) .setMaxTokens(200) .setOutputCol("out") .setPromptCol("prompt") @@ -26,6 +25,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setDeploymentName(deploymentName) OpenAIDefaults.setSubscriptionKey(openAIAPIKey) OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") val fromRow = CompletionResponse.makeFromRowConverter promptCompletion.transform(promptDF).collect().foreach(r => @@ -34,7 +34,6 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { } lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setCustomServiceName(openAIServiceName) .setOutputCol("outParsed") lazy val df: DataFrame = Seq( @@ -48,6 +47,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setDeploymentName(deploymentName) OpenAIDefaults.setSubscriptionKey(openAIAPIKey) OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") @@ -65,22 +65,31 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { } test("Test Getters") { + OpenAIDefaults.setDeploymentName(deploymentName) + OpenAIDefaults.setSubscriptionKey(openAIAPIKey) + OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") + assert(OpenAIDefaults.getDeploymentName.contains(deploymentName)) assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey)) assert(OpenAIDefaults.getTemperature.contains(0.05)) + assert(OpenAIDefaults.getURL.contains(s"https://$openAIServiceName.openai.azure.com/")) } test("Test Resetters") { OpenAIDefaults.setDeploymentName(deploymentName) OpenAIDefaults.setSubscriptionKey(openAIAPIKey) OpenAIDefaults.setTemperature(0.05) + OpenAIDefaults.setURL(s"https://$openAIServiceName.openai.azure.com/") OpenAIDefaults.resetDeploymentName() OpenAIDefaults.resetSubscriptionKey() OpenAIDefaults.resetTemperature() + OpenAIDefaults.resetURL() assert(OpenAIDefaults.getDeploymentName.isEmpty) assert(OpenAIDefaults.getSubscriptionKey.isEmpty) assert(OpenAIDefaults.getTemperature.isEmpty) + assert(OpenAIDefaults.getURL.isEmpty) } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index 8d942e34b1..43ea24d112 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.codegen.Wrappable import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol} import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.HandlerFunc import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.UDFParam +import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, UDFParam} import org.apache.http.impl.client.CloseableHttpClient import org.apache.spark.injections.UDFUtils import org.apache.spark.ml.param._ @@ -76,10 +76,14 @@ trait ConcurrencyParams extends Wrappable { setDefault(concurrency -> 1, timeout -> 60.0) } +case object URLKey extends GlobalKey[String] + trait HasURL extends Params { val url: Param[String] = new Param[String](this, "url", "Url of the service") + GlobalParams.registerParam(url, URLKey) + /** @group getParam */ def getUrl: String = $(url)