diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index 3292fe18ca..3a2566fd5f 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -42,7 +42,7 @@ def reset_subscription_key(self): self.defaults.resetSubscriptionKey() def set_temperature(self, temp): - self.defaults.setTemperature(temp) + self.defaults.setTemperature(float(temp)) def get_temperature(self): return getOption(self.defaults.getTemperature()) diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py index d239964767..e276378dbb 100644 --- a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py @@ -3,8 +3,10 @@ from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt -import unittest +import unittest,os, json, subprocess from pyspark.sql import SQLContext +from pyspark.sql.functions import col + from synapse.ml.core.init_spark import * @@ -58,6 +60,34 @@ def test_two_defaults(self): defaults.set_deployment_name("Test 1") self.assertEqual(defaults.get_deployment_name(), "Test 1") + def test_prompt_w_defaults(self): + + secretJson = subprocess.check_output( + "az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2", + shell=True, + ) + openai_api_key = json.loads(secretJson)["value"] + + df = spark.createDataFrame([ + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + ], ["text", "category"]) + + defaults = OpenAIDefaults() + defaults.set_deployment_name("gpt-35-turbo-0125") + defaults.set_subscription_key(openai_api_key) + defaults.set_temperature(0.05) + + prompt = OpenAIPrompt() + prompt = prompt.setOutputCol("outParsed") + prompt = prompt.setCustomServiceName("synapseml-openai-2") + prompt = prompt.setPromptTemplate("Complete this comma separated list of 5 {category}: {text}, ") + results = prompt.transform(df) + results.select("outParsed").show(truncate = False) + nonNullCount = results.filter(col("outParsed").isNotNull()).count() + assert (nonNullCount == 3) + if __name__ == "__main__": result = unittest.main()