From b4e4d449074ca6a0f95e652aa7c8b888798db052 Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Fri, 22 Nov 2024 11:47:40 -0800 Subject: [PATCH 1/3] In this Changeset, I am: 1. Ensuring that new prompt instances are created for each test. This will eliminate chance of one test interfering with another test. 2. Creating test to validate that for older model deployments, OpenAiCompletion instance is created in OpenAIPrompt class. 3. Fixing the bug that caused creation of OpenAICompletion object to fail. 4. Cleaning some code in OpenAIPrompt class. --- .../synapse/ml/services/openai/OpenAI.scala | 5 +- .../ml/services/openai/OpenAIPrompt.scala | 24 +++----- .../services/openai/OpenAIPromptSuite.scala | 56 ++++++++++++++----- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 5435fcdbf3..b2c29b183b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -8,6 +8,7 @@ import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services._ +import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.sql.Row @@ -294,7 +295,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { } abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) - with HasOpenAISharedParams with OpenAIFabricSetting { + with HasOpenAISharedParams with OpenAIFabricSetting with HasCognitiveServiceInput { setDefault(timeout -> 360.0) private def usingDefaultOpenAIEndpoint(): Boolean = { @@ -307,4 +308,6 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer } super.getInternalTransformer(schema) } + + private[openai] def prepareEntityAIService: Row => Option[AbstractHttpEntity] = this.prepareEntity } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index a43f3ffe3a..7db574240c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -106,9 +106,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer df.map({ row => val originalOutput = Option(row.getAs[Row](outputCol)) .map({ row => openAIResultFromRow(row).choices.head }) - val isFiltered = originalOutput - .map(output => Option(output.message.content).isEmpty) - .getOrElse(false) + val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty) if (isFiltered) { val updatedRowSeq = row.toSeq.updated( @@ -183,7 +181,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer "text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003", "code-cushman-001", "code-davinci-002") - private def openAICompletion: OpenAIServicesBase = { + /** + * This method is made available in the `openai` package for testing purposes. + */ + private[openai] def openAICompletion: OpenAIServicesBase = { val completion: OpenAIServicesBase = if (legacyModels.contains(getDeploymentName)) { @@ -194,29 +195,22 @@ class OpenAIPrompt(override val uid: String) extends Transformer } // apply all parameters extractParamMap().toSeq + .filter(p => completion.hasParam(p.param.name)) .filter(p => !localParamNames.contains(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) completion } - override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { - r => - openAICompletion match { - case chatCompletion: OpenAIChatCompletion => - chatCompletion.prepareEntity(r) - case completion: OpenAICompletion => - completion.prepareEntity(r) - } - } + override protected def prepareEntity: Row => Option[AbstractHttpEntity] = openAICompletion.prepareEntityAIService private def getParser: OutputParser = { val opts = getPostProcessingOptions getPostProcessing.toLowerCase match { case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ",")) - case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty) - case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt) + case "json" => new JsonParser(opts("jsonSchema"), Map.empty) + case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt) case "" => new PassThroughParser() case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'") } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 9409d2f358..1d883f5f01 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -21,12 +21,42 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK super.beforeAll() } - lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) + private def createPromptInstance(deploymentName: String) = { + new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + } + + + test("Validating OpenAICompletion and OpenAIChatCompletion for correctModels") { + val generation1Models: Seq[String] = Seq("ada", + "babbage", + "curie", + "davinci", + "text-ada-001", + "text-babbage-001", + "text-curie-001", + "text-davinci-002", + "text-davinci-003", + "code-cushman-001", + "code-davinci-002") + + val generation2Models: Seq[String] = Seq("gpt-35-turbo", + "gpt-4-turbo", + "gpt-4o") + + def validateModelType(models: Seq[String], expectedType: Class[_]): Unit = { + models.foreach { model => + assert(createPromptInstance(model).openAICompletion.getClass == expectedType) + } + } + + validateModelType(generation1Models, classOf[OpenAICompletion]) + validateModelType(generation2Models, classOf[OpenAIChatCompletion]) + } lazy val df: DataFrame = Seq( ("apple", "fruits"), @@ -36,8 +66,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK ).toDF("text", "category") test("RAI Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) val result = prompt - .setDeploymentName(deploymentNameGpt4) .setPromptTemplate("Tell me about a graphically disgusting movie in detail") .transform(df) .select(prompt.getErrorCol) @@ -46,6 +76,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -58,6 +89,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -72,14 +104,9 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } - lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentNameGpt4) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) test("Basic Usage - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) val nonNullCount = promptGpt4 .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -92,6 +119,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) promptGpt4.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -107,6 +135,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Setting and Keeping Messages Col - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) promptGpt4.setMessagesCol("messages") .setDropPrompt(false) .setPromptTemplate( @@ -154,6 +183,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val testPrompt = prompt .setPromptTemplate("{text} rhymes with ") From 7435ae889382e49cedc2f89d04fa7f312bccf62c Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Fri, 22 Nov 2024 14:39:54 -0800 Subject: [PATCH 2/3] fixing failing unit tests --- .../azure/synapse/ml/services/openai/OpenAIPrompt.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 7db574240c..fcc856009e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -99,7 +99,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer private val localParamNames = Seq( "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages", - "systemPrompt") + "systemPrompt", "messagesCol") private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = { val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter @@ -191,11 +191,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer new OpenAICompletion() } else { - new OpenAIChatCompletion() + new OpenAIChatCompletion().setMessagesCol(getMessagesCol) } // apply all parameters extractParamMap().toSeq - .filter(p => completion.hasParam(p.param.name)) .filter(p => !localParamNames.contains(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) From ddc263de4088359807cc1147ab3771ef0cb72545 Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Mon, 25 Nov 2024 14:56:04 -0800 Subject: [PATCH 3/3] fixing failing test --- .../azure/synapse/ml/services/openai/OpenAIPrompt.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index fcc856009e..bb690ec23a 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -99,7 +99,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer private val localParamNames = Seq( "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages", - "systemPrompt", "messagesCol") + "systemPrompt") private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = { val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter