diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 49db4911f2..332f83bbb5 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -60,7 +60,32 @@ class OpenAIPrompt(override val uid: String) extends Transformer def getPostProcessingOptions: Map[String, String] = $(postProcessingOptions) - def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value) + def setPostProcessingOptions(value: Map[String, String]): this.type = { + // Helper method to set or validate the postProcessing parameter + def setOrValidatePostProcessing(expected: String): Unit = { + if (isSet(postProcessing)) { + require(getPostProcessing == expected, s"postProcessing must be '$expected'") + } else { + set(postProcessing, expected) + } + } + + // Match on the keys in the provided value map to set the appropriate post-processing option + value match { + case v if v.contains("delimiter") => + setOrValidatePostProcessing("csv") + case v if v.contains("jsonSchema") => + setOrValidatePostProcessing("json") + case v if v.contains("regex") => + require(v.contains("regexGroup"), "regexGroup must be specified with regex") + setOrValidatePostProcessing("regex") + case _ => + throw new IllegalArgumentException("Invalid post processing options") + } + + // Set the postProcessingOptions parameter with the provided value map + set(postProcessingOptions, value) + } def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type = set(postProcessingOptions, v.asScala.toMap) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 9409d2f358..ea24669f04 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -106,6 +106,20 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } + test("Basic Usage JSON - Gpt 4 without explicit post-processing") { + promptGpt4.setPromptTemplate( + """Split a word into prefix and postfix a respond in JSON + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) + .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) + .transform(df) + .select("outParsed") + .where(col("outParsed").isNotNull) + .collect() + .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) + } + test("Setting and Keeping Messages Col - Gpt 4") { promptGpt4.setMessagesCol("messages") .setDropPrompt(false) @@ -149,6 +163,46 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .count(r => Option(r.getSeq[String](0)).isDefined) } + test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("delimiter" -> ",")) + assert(prompt.getPostProcessing == "csv") + } + + test("setPostProcessingOptions should set postProcessing to 'json' for jsonSchema option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema")) + assert(prompt.getPostProcessing == "json") + } + + test("setPostProcessingOptions should set postProcessing to 'regex' for regex option") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessingOptions(Map("regex" -> ".*", "regexGroup" -> "0")) + assert(prompt.getPostProcessing == "regex") + } + + test("setPostProcessingOptions should throw IllegalArgumentException for invalid options") { + val prompt = new OpenAIPrompt() + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("invalidOption" -> "value")) + } + } + + test("setPostProcessingOptions should validate regex options contain regexGroup key") { + val prompt = new OpenAIPrompt() + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("regex" -> ".*")) + } + } + + test("setPostProcessingOptions should validate existing postProcessing value") { + val prompt = new OpenAIPrompt() + prompt.setPostProcessing("csv") + intercept[IllegalArgumentException] { + prompt.setPostProcessingOptions(Map("jsonSchema" -> "schema")) + } + } + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) }