diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 5435fcdbf3..438a27791c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -271,7 +271,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { // list of shared text parameters. In method getOptionalParams, we will iterate over these parameters // to compute the optional parameters. Since this list never changes, we can create it once and reuse it. - private val sharedTextParams = Seq( + private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq( maxTokens, temperature, topP, diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index 379d797766..8ae2819cf6 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat +import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser} import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} import org.apache.spark.ml.ComplexParamsReadable @@ -18,8 +19,84 @@ import scala.language.existentials object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion] +object OpenAIResponseFormat extends Enumeration { + case class ResponseFormat(name: String, prompt: String) extends super.Val(name) + val TEXT: ResponseFormat = ResponseFormat("text", "Output must be in text format") + val JSON: ResponseFormat = ResponseFormat("json_object", "Output must be in JSON format") +} + + +trait HasOpenAITextParamsExtended extends HasOpenAITextParams { + val responseFormat: ServiceParam[Map[String, String]] = new ServiceParam[Map[String, String]]( + this, + "responseFormat", + "Response format for the completion. Can be 'json_object' or 'text'.", + isRequired = false) { + override val payloadName: String = "response_format" + } + + def getResponseFormat: Map[String, String] = getScalarParam(responseFormat) + + def setResponseFormat(value: Map[String, String]): this.type = { + if (!OpenAIResponseFormat.values.map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name) + .contains(value("type"))) { + throw new IllegalArgumentException("Response format must be 'text' or 'json_object'") + } + setScalarParam(responseFormat, value) + } + + def setResponseFormat(value: String): this.type = { + if (value.isEmpty) { + this + } else { + val normalizedValue = value.toLowerCase match { + case "json" => "json_object" + case other => other + } + // Validate the normalized value using the OpenAIResponseFormat enum + if (!OpenAIResponseFormat.values + .map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name) + .contains(normalizedValue)) { + throw new IllegalArgumentException("Response format must be valid for OpenAI API. " + + "Currently supported formats are " + OpenAIResponseFormat.values + .map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name) + .mkString(", ")) + } + + setScalarParam(responseFormat, Map("type" -> normalizedValue)) + } + } + + def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = { + // this method should throw an excption if the openAiCompletion is not a ChatCompletion + this.setResponseFormat(value.name) + } + + def getResponseFormatCol: String = getVectorParam(responseFormat) + + def setResponseFormatCol(value: String): this.type = setVectorParam(responseFormat, value) + + + // Recreating the sharedTextParams sequence to include additional parameter responseFormat + override private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq( + maxTokens, + temperature, + topP, + user, + n, + echo, + stop, + cacheLevel, + presencePenalty, + frequencyPenalty, + bestOf, + logProbs, + responseFormat // Additional parameter + ) +} + class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid) - with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput + with HasOpenAITextParamsExtended with HasMessagesInput with HasCognitiveServiceInput with HasInternalJsonOutputParser with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) @@ -55,12 +132,23 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( override def responseDataType: DataType = ChatCompletionResponse.schema private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = { - val mappedMessages: Seq[Map[String, String]] = messages.map { m => - Seq("role", "content", "name").map(n => - n -> Option(m.getAs[String](n)) - ).toMap.filter(_._2.isDefined).mapValues(_.get) + // Convert each message row to a map of non-null values + val mappedMessages: Seq[Map[String, String]] = messages.map { messageRow => + Seq("role", "content", "name").map { fieldName => + fieldName -> Option(messageRow.getAs[String](fieldName)) + }.toMap.filter(_._2.isDefined).mapValues(_.get) + } + + // Check if the response format is JSON and add a system message if needed + val updatedMessages = if (optionalParams.get("response_format") + .exists(_.asInstanceOf[Map[String, String]]("type") == "json_object")) { + mappedMessages :+ Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt) + } else { + mappedMessages } - val fullPayload = optionalParams.updated("messages", mappedMessages) + + // Update the optional parameters with the messages + val fullPayload = optionalParams.updated("messages", updatedMessages) new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON) } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index a43f3ffe3a..d7bbba630d 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt] class OpenAIPrompt(override val uid: String) extends Transformer - with HasOpenAITextParams with HasMessagesInput + with HasOpenAITextParamsExtended with HasMessagesInput with HasErrorCol with HasOutputCol with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader @@ -60,7 +60,37 @@ class OpenAIPrompt(override val uid: String) extends Transformer def getPostProcessingOptions: Map[String, String] = $(postProcessingOptions) - def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value) + def setPostProcessingOptions(value: Map[String, String]): this.type = { + // Helper method to validate that regex options contain the required "regexGroup" key + def validateRegexOptions(options: Map[String, String]): Unit = { + require(options.contains("regexGroup"), "regexGroup must be specified with regex") + } + + // Helper method to set or validate the postProcessing parameter + def setOrValidatePostProcessing(expected: String): Unit = { + if (isSet(postProcessing)) { + require(getPostProcessing == expected, s"postProcessing must be '$expected'") + } else { + set(postProcessing, expected) + } + } + + // Match on the keys in the provided value map to set the appropriate post-processing option + value match { + case v if v.contains("delimiter") => + setOrValidatePostProcessing("csv") + case v if v.contains("jsonSchema") => + setOrValidatePostProcessing("json") + case v if v.contains("regex") => + validateRegexOptions(v) + setOrValidatePostProcessing("regex") + case _ => + throw new IllegalArgumentException("Invalid post processing options") + } + + // Set the postProcessingOptions parameter with the provided value map + set(postProcessingOptions, value) + } def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type = set(postProcessingOptions, v.asScala.toMap) @@ -99,16 +129,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer private val localParamNames = Seq( "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages", - "systemPrompt") + "systemPrompt", "responseFormat") private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = { val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter df.map({ row => val originalOutput = Option(row.getAs[Row](outputCol)) .map({ row => openAIResultFromRow(row).choices.head }) - val isFiltered = originalOutput - .map(output => Option(output.message.content).isEmpty) - .getOrElse(false) + val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty) if (isFiltered) { val updatedRowSeq = row.toSeq.updated( @@ -127,7 +155,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer logTransform[DataFrame]({ val df = dataset.toDF - val completion = openAICompletion val promptCol = Functions.template(getPromptTemplate) val createMessagesUDF = udf((userMessage: String) => { @@ -138,18 +165,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer }) completion match { case chatCompletion: OpenAIChatCompletion => + if (isSet(responseFormat)) { + chatCompletion.setResponseFormat(getResponseFormat) + } val messageColName = getMessagesCol val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol)) val completionNamed = chatCompletion.setMessagesCol(messageColName) - val transformed = addRAIErrors( completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol) val results = transformed .withColumn(getOutputCol, getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1) - .getField("message").getField("content"))) - .drop(completionNamed.getOutputCol) + .getField("message").getField("content"))).drop(completionNamed.getOutputCol) if (getDropPrompt) { results.drop(messageColName) @@ -158,6 +186,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer } case completion: OpenAICompletion => + if (isSet(responseFormat)) { + throw new IllegalArgumentException("responseFormat is not supported for OpenAICompletion") + } val promptColName = df.withDerivativeCol("prompt") val dfTemplated = df.withColumn(promptColName, promptCol) val completionNamed = completion.setPromptCol(promptColName) @@ -215,8 +246,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer getPostProcessing.toLowerCase match { case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ",")) - case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty) - case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt) + case "json" => new JsonParser(opts("jsonSchema"), Map.empty) + case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt) case "" => new PassThroughParser() case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'") } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala index 89458decf8..416a9faa8b 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala @@ -8,6 +8,8 @@ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, Transformer import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.scalactic.Equality +import org.scalatest.matchers.must.Matchers.{an, be} +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] with OpenAIAPIKey with Flaky { @@ -174,6 +176,45 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] testCompletion(customEndpointCompletion, goodDf) } + test("setResponseFormat should set the response format correctly") { + completion.setResponseFormat("text") + completion.getResponseFormat shouldEqual Map("type" -> "text") + + completion.setResponseFormat("tExT") + completion.getResponseFormat shouldEqual Map("type" -> "text") + + completion.setResponseFormat("json") + completion.getResponseFormat shouldEqual Map("type" -> "json_object") + + completion.setResponseFormat("JSON") + completion.getResponseFormat shouldEqual Map("type" -> "json_object") + + completion.setResponseFormat("json_object") + completion.getResponseFormat shouldEqual Map("type" -> "json_object") + + completion.setResponseFormat("Json_ObjeCt") + completion.getResponseFormat shouldEqual Map("type" -> "json_object") + } + + test("setResponseFormat should throw an exception for invalid response format") { + an[IllegalArgumentException] should be thrownBy { + completion.setResponseFormat("invalid_format") + } + } + + test("setResponseFormat with ResponseFormat should set the response format correctly") { + completion.setResponseFormat(OpenAIResponseFormat.TEXT) + completion.getResponseFormat shouldEqual Map("type" -> "text") + + completion.setResponseFormat(OpenAIResponseFormat.JSON) + completion.getResponseFormat shouldEqual Map("type" -> "json_object") + } + + test("setResponseFormatCol should set the response format column correctly") { + completion.setResponseFormatCol("response_format_col") + completion.getResponseFormatCol shouldEqual "response_format_col" + } + def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = { val fromRow = ChatCompletionResponse.makeFromRowConverter completion.transform(df).collect().foreach(r => diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala index 807426c468..209778d805 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala @@ -17,6 +17,8 @@ trait OpenAIAPIKey { lazy val deploymentName: String = "gpt-35-turbo" lazy val modelName: String = "gpt-35-turbo" lazy val deploymentNameGpt4: String = "gpt-4" + lazy val deploymentNameDavinci3: String = "text-davinci-003" + lazy val deploymentNameGpt4o: String = "gpt-4o" lazy val modelNameGpt4: String = "gpt-4" } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 9409d2f358..39bcf12d15 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -10,6 +10,8 @@ import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.col import org.scalactic.Equality +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { @@ -21,12 +23,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK super.beforeAll() } - lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) lazy val df: DataFrame = Seq( ("apple", "fruits"), @@ -36,8 +32,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK ).toDF("text", "category") test("RAI Usage") { + val prompt = createPromptInstance(deploymentNameGpt4) val result = prompt - .setDeploymentName(deploymentNameGpt4) .setPromptTemplate("Tell me about a graphically disgusting movie in detail") .transform(df) .select(prompt.getErrorCol) @@ -46,6 +42,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -57,7 +54,20 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK assert(nonNullCount == 3) } + test("Basic Usage with only post processing options") { + val nonNullCount = createPromptInstance(deploymentName) + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessingOptions(Map("delimiter" -> ",")) + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + + assert(nonNullCount == 3) + } + test("Basic Usage JSON") { + val prompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -72,14 +82,40 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } - lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentNameGpt4) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) + test("Basic Usage JSON using text response format") { + val prompt = createPromptInstance(deploymentName) + prompt.setPromptTemplate( + """Split a word into prefix and postfix a respond in JSON + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) + .setResponseFormat("text") + .setPostProcessing("json") + .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) + .transform(df) + .select("outParsed") + .where(col("outParsed").isNotNull) + .collect() + .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) + } + + test("Basic Usage JSON using only post processing oiptions") { + val prompt = createPromptInstance(deploymentName) + prompt.setPromptTemplate( + """Split a word into prefix and postfix a respond in JSON + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) + .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) + .transform(df) + .select("outParsed") + .where(col("outParsed").isNotNull) + .collect() + .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) + } test("Basic Usage - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) val nonNullCount = promptGpt4 .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -92,6 +128,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) promptGpt4.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -106,7 +143,67 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } + test("Basic Usage JSON - Gpt 4 using responseFormat TEXT") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) + promptGpt4.setPromptTemplate( + """Split a word into prefix and postfix a respond in JSON + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) + .setPostProcessing("json") + .setResponseFormat("text") + .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) + .transform(df) + .select("outParsed") + .where(col("outParsed").isNotNull) + .collect() + .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) + } + + test("Basic Usage JSON - Gpt 4o using responseFormat JSON") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) + promptGpt4o.setPromptTemplate( + """Split a word into prefix and postfix + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) + .setResponseFormat("json") + .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) + .transform(df) + .select("outParsed") + .where(col("outParsed").isNotNull) + .collect() + .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) + } + + test("Basic Usage - Gpt 4o with response format json") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) + val nonNullCount = promptGpt4o + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setResponseFormat(OpenAIResponseFormat.JSON) + .transform(df) + .select("outParsed") + .collect() + .length + + assert(nonNullCount == 4) + } + + test("Basic Usage - Gpt 4o with response format text") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) + val nonNullCount = promptGpt4o + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setResponseFormat(OpenAIResponseFormat.TEXT) + .transform(df) + .select("outParsed") + .collect() + .length + + assert(nonNullCount == 4) + } + test("Setting and Keeping Messages Col - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) promptGpt4.setMessagesCol("messages") .setDropPrompt(false) .setPromptTemplate( @@ -122,6 +219,31 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.get(0) != null)) } + ignore("Basic Usage - Davinci 3 with no response format") { + val promptDavinci3 = createPromptInstance(deploymentNameDavinci3) + val rowCount = promptDavinci3 + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .transform(df) + .select("outParsed") + .collect() + .length + assert(rowCount == 4) + } + + ignore("Basic Usage - Davinci 3 with response format json") { + val promptDavinci3 = createPromptInstance(deploymentNameDavinci3) + intercept[IllegalArgumentException] { + promptDavinci3 + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setResponseFormat(OpenAIResponseFormat.JSON) + .transform(df) + .select("outParsed") + .collect() + .length + } + } + + ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") @@ -149,11 +271,71 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .count(r => Option(r.getSeq[String](0)).isDefined) } + test("setResponseFormat should set the response format correctly with String") { + val prompt = new OpenAIPrompt() + prompt.setResponseFormat("json") + prompt.getResponseFormat shouldEqual Map("type" -> "json_object") + + prompt.setResponseFormat("json_object") + prompt.getResponseFormat shouldEqual Map("type" -> "json_object") + + prompt.setResponseFormat("text") + prompt.getResponseFormat shouldEqual Map("type" -> "text") + } + + test("setResponseFormat should throw an exception for invalid response format") { + val prompt = new OpenAIPrompt() + an[IllegalArgumentException] should be thrownBy { + prompt.setResponseFormat("invalid_format") + } + } + + test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("delimiter" -> ",")) + prompt.getPostProcessing should be ("csv") + } + + test("setPostProcessingOptions should set postProcessing to 'json' for jsonSchema option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema")) + prompt.getPostProcessing should be ("json") + } + + test("setPostProcessingOptions should set postProcessing to 'regex' for regex option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("regex" -> ".*", "regexGroup" -> "0")) + prompt.getPostProcessing should be ("regex") + } + + test("setPostProcessingOptions should throw IllegalArgumentException for invalid options") { + val prompt = new OpenAIPrompt() + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("invalidOption" -> "value")) + } + } + + test("setPostProcessingOptions should validate regex options contain regexGroup key") { + val prompt = new OpenAIPrompt() + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("regex" -> ".*")) + } + } + + test("setPostProcessingOptions should validate existing postProcessing value") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessing("csv") + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema")) + } + } + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) } override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val prompt = createPromptInstance(deploymentName) val testPrompt = prompt .setPromptTemplate("{text} rhymes with ") @@ -162,4 +344,12 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK override def reader: MLReadable[_] = OpenAIPrompt + private def createPromptInstance(deploymentName: String): OpenAIPrompt = { + new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + } }