From 478e8a24812195d1cef32035ee865dc72e7d7bf9 Mon Sep 17 00:00:00 2001
From: Mark Hamilton
Date: Fri, 6 Sep 2024 11:54:24 -0400
Subject: [PATCH] fix: fix openai prompt behavior on RAI errors (#2279)

---
 .../ml/services/openai/OpenAIPrompt.scala     | 45 ++++++++++++++-----
 .../ml/services/openai/OpenAISchemas.scala    |  6 ++-
 .../services/openai/OpenAIPromptSuite.scala   | 12 ++++-
 .../ml/io/http/SimpleHTTPTransformer.scala    |  1 +
 4 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
index 47f2de6204..66b42833e1 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
@@ -13,6 +13,8 @@ import org.apache.http.entity.AbstractHttpEntity
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
+import org.apache.spark.sql.Row.unapplySeq
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
@@ -78,7 +80,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

   def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

-  private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
+  private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
     "Follow their instructions carefully and be brief if they don't say otherwise."

   setDefault(
@@ -100,6 +102,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer
     "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
     "systemPrompt")

+  private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
+    val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
+    df.map({ row =>
+      val originalOutput = Option(row.getAs[Row](outputCol))
+        .map({ row => openAIResultFromRow(row).choices.head })
+      val isFiltered = originalOutput
+        .map(output => Option(output.message.content).isEmpty)
+        .getOrElse(false)
+
+      if (isFiltered) {
+        val updatedRowSeq = row.toSeq.updated(
+          row.fieldIndex(errorCol),
+          Row(originalOutput.get.finish_reason, null) //scalastyle:ignore null
+        )
+        Row.fromSeq(updatedRowSeq)
+      } else {
+        row
+      }
+    })(RowEncoder(df.schema))
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
     import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

@@ -120,8 +143,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
           val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
           val completionNamed = chatCompletion.setMessagesCol(messageColName)

-          val results = completionNamed
-            .transform(dfTemplated)
+          val transformed = addRAIErrors(
+            completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)
+
+          val results = transformed
             .withColumn(getOutputCol,
               getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
                 .getField("message").getField("content")))
@@ -155,19 +180,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
       }, dataset.columns.length)
   }

-  private val legacyModels = Set("ada","babbage", "curie", "davinci",
+  private val legacyModels = Set("ada", "babbage", "curie", "davinci",
     "text-ada-001", "text-babbage-001", "text-curie-001",
     "text-davinci-002", "text-davinci-003",
     "code-cushman-001", "code-davinci-002")

   private def openAICompletion: OpenAIServicesBase = {
     val completion: OpenAIServicesBase =
-    if (legacyModels.contains(getDeploymentName)) {
-      new OpenAICompletion()
-    }
-    else {
-      new OpenAIChatCompletion()
-    }
+      if (legacyModels.contains(getDeploymentName)) {
+        new OpenAICompletion()
+      }
+      else {
+        new OpenAIChatCompletion()
+      }
     // apply all parameters
     extractParamMap().toSeq
       .filter(p => !localParamNames.contains(p.param.name))
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala
index 6148b67ca9..7d2e9238ae 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala
@@ -41,11 +41,15 @@ case class OpenAIChatChoice(message: OpenAIMessage,
                             index: Long,
                             finish_reason: String)

+case class OpenAIUsage(completion_tokens: Long, prompt_tokens: Long, total_tokens: Long)
+
 case class ChatCompletionResponse(id: String,
                                   `object`: String,
                                   created: String,
                                   model: String,
-                                  choices: Seq[OpenAIChatChoice])
+                                  choices: Seq[OpenAIChatChoice],
+                                  system_fingerprint: Option[String],
+                                  usage: Option[OpenAIUsage])

 object ChatCompletionResponse extends SparkBindings[ChatCompletionResponse]

diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
index 773cbe7547..9409d2f358 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
@@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
 import com.microsoft.azure.synapse.ml.core.test.base.Flaky
 import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
 import org.apache.spark.ml.util.MLReadable
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.col
 import org.scalactic.Equality

@@ -35,6 +35,16 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
     (null, "none") //scalastyle:ignore null
   ).toDF("text", "category")

+  test("RAI Usage") {
+    val result = prompt
+      .setDeploymentName(deploymentNameGpt4)
+      .setPromptTemplate("Tell me about a graphically disgusting movie in detail")
+      .transform(df)
+      .select(prompt.getErrorCol)
+      .collect().head.getAs[Row](0)
+    assert(Option(result).nonEmpty)
+  }
+
   test("Basic Usage") {
     val nonNullCount = prompt
       .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala
index c7248a73fd..e30c3a1970 100644
--- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala
+++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala
@@ -30,6 +30,7 @@ trait HasErrorCol extends Params {
 }

 object ErrorUtils extends Serializable {
+
   val ErrorSchema: StructType = new StructType()
     .add("response", StringType, nullable = true)
     .add("status", StatusLineData.schema, nullable = true)
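
A minimal usage sketch of the behavior the new "RAI Usage" test exercises, assuming the `prompt`, `df`, and `deploymentNameGpt4` fixtures that OpenAIPromptSuite defines outside this diff: after this change, a request blocked by Azure OpenAI content filtering should surface the choice's finish_reason through the error column instead of silently returning a choice with a null message.

import org.apache.spark.sql.Row

// Sketch only: `prompt`, `df`, and `deploymentNameGpt4` are suite fixtures not shown in this patch.
val errorRow = prompt
  .setDeploymentName(deploymentNameGpt4)
  .setPromptTemplate("Tell me about a graphically disgusting movie in detail")
  .transform(df)                      // run the templated chat completion over the input frame
  .select(prompt.getErrorCol)         // error column that addRAIErrors populates for filtered rows
  .collect().head.getAs[Row](0)

// With the fix, a filtered request yields a non-null error Row whose first ("response") field
// carries the finish_reason (e.g. "content_filter") while the "status" field is left null.
assert(Option(errorRow).nonEmpty)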