diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 5435fcdbf3..b2c29b183b 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -8,6 +8,7 @@ import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services._ +import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.sql.Row @@ -294,7 +295,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { } abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) - with HasOpenAISharedParams with OpenAIFabricSetting { + with HasOpenAISharedParams with OpenAIFabricSetting with HasCognitiveServiceInput { setDefault(timeout -> 360.0) private def usingDefaultOpenAIEndpoint(): Boolean = { @@ -307,4 +308,6 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer } super.getInternalTransformer(schema) } + + private[openai] def prepareEntityAIService: Row => Option[AbstractHttpEntity] = this.prepareEntity } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index a43f3ffe3a..7db574240c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -106,9 +106,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer df.map({ row => val originalOutput = Option(row.getAs[Row](outputCol)) .map({ row => openAIResultFromRow(row).choices.head }) - val isFiltered = originalOutput - .map(output => Option(output.message.content).isEmpty) - .getOrElse(false) + val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty) if (isFiltered) { val updatedRowSeq = row.toSeq.updated( @@ -183,7 +181,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer "text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003", "code-cushman-001", "code-davinci-002") - private def openAICompletion: OpenAIServicesBase = { + /** + * This method is made available in the `openai` package for testing purposes. + */ + private[openai] def openAICompletion: OpenAIServicesBase = { val completion: OpenAIServicesBase = if (legacyModels.contains(getDeploymentName)) { @@ -194,29 +195,22 @@ class OpenAIPrompt(override val uid: String) extends Transformer } // apply all parameters extractParamMap().toSeq + .filter(p => completion.hasParam(p.param.name)) .filter(p => !localParamNames.contains(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) completion } - override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { - r => - openAICompletion match { - case chatCompletion: OpenAIChatCompletion => - chatCompletion.prepareEntity(r) - case completion: OpenAICompletion => - completion.prepareEntity(r) - } - } + override protected def prepareEntity: Row => Option[AbstractHttpEntity] = openAICompletion.prepareEntityAIService private def getParser: OutputParser = { val opts = getPostProcessingOptions getPostProcessing.toLowerCase match { case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ",")) - case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty) - case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt) + case "json" => new JsonParser(opts("jsonSchema"), Map.empty) + case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt) case "" => new PassThroughParser() case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'") } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 9409d2f358..1d883f5f01 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -21,12 +21,42 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK super.beforeAll() } - lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) + private def createPromptInstance(deploymentName: String) = { + new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + } + + + test("Validating OpenAICompletion and OpenAIChatCompletion for correctModels") { + val generation1Models: Seq[String] = Seq("ada", + "babbage", + "curie", + "davinci", + "text-ada-001", + "text-babbage-001", + "text-curie-001", + "text-davinci-002", + "text-davinci-003", + "code-cushman-001", + "code-davinci-002") + + val generation2Models: Seq[String] = Seq("gpt-35-turbo", + "gpt-4-turbo", + "gpt-4o") + + def validateModelType(models: Seq[String], expectedType: Class[_]): Unit = { + models.foreach { model => + assert(createPromptInstance(model).openAICompletion.getClass == expectedType) + } + } + + validateModelType(generation1Models, classOf[OpenAICompletion]) + validateModelType(generation2Models, classOf[OpenAIChatCompletion]) + } lazy val df: DataFrame = Seq( ("apple", "fruits"), @@ -36,8 +66,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK ).toDF("text", "category") test("RAI Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) val result = prompt - .setDeploymentName(deploymentNameGpt4) .setPromptTemplate("Tell me about a graphically disgusting movie in detail") .transform(df) .select(prompt.getErrorCol) @@ -46,6 +76,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -58,6 +89,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -72,14 +104,9 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } - lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentNameGpt4) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) test("Basic Usage - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) val nonNullCount = promptGpt4 .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -92,6 +119,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) promptGpt4.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -107,6 +135,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Setting and Keeping Messages Col - Gpt 4") { + val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4) promptGpt4.setMessagesCol("messages") .setDropPrompt(false) .setPromptTemplate( @@ -154,6 +183,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val testPrompt = prompt .setPromptTemplate("{text} rhymes with ")