Skip to content

Commit

Permalink
fix: fix openai prompt behavior on RAI errors (#2279)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 authored Sep 6, 2024
1 parent 392f601 commit 478e8a2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.Row.unapplySeq
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
Expand Down Expand Up @@ -78,7 +80,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

// System message used when the caller does not set `systemPrompt` explicitly.
// (The scraped diff duplicated this line; only one definition is valid Scala.)
private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
  "Follow their instructions carefully and be brief if they don't say otherwise."

setDefault(
Expand All @@ -100,6 +102,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")

/** Surfaces Responsible-AI (content-filter) suppressions as errors.
  *
  * When the service returns a choice whose `message.content` is null, the
  * completion was filtered rather than failed; this rewrites such rows so the
  * error column carries the `finish_reason` instead of leaving the row
  * looking successful.
  *
  * @param df        output of the chat-completion transformer
  * @param errorCol  name of the error column to populate on filtered rows
  * @param outputCol name of the column holding the raw completion response
  * @return the same DataFrame with filtered rows' error column populated
  */
private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
  val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
  df.map({ row =>
    // First choice of the response, if the service produced any output at all.
    // headOption (rather than head) leaves rows with an empty choice list
    // untouched instead of throwing.
    val firstChoice = Option(row.getAs[Row](outputCol))
      .flatMap(outputRow => openAIResultFromRow(outputRow).choices.headOption)
    firstChoice match {
      // Null content with a present choice means the RAI filter suppressed
      // the completion: record its finish_reason in the error column.
      case Some(choice) if Option(choice.message.content).isEmpty =>
        val updatedRowSeq = row.toSeq.updated(
          row.fieldIndex(errorCol),
          Row(choice.finish_reason, null) //scalastyle:ignore null
        )
        Row.fromSeq(updatedRowSeq)
      case _ => row
    }
  })(RowEncoder(df.schema))
}

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

Expand All @@ -120,8 +143,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

val results = completionNamed
.transform(dfTemplated)
val transformed = addRAIErrors(
completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)

val results = transformed
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
Expand Down Expand Up @@ -155,19 +180,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}, dataset.columns.length)
}

// Deployment names that only support the legacy (non-chat) completions API;
// `openAICompletion` routes these to OpenAICompletion instead of
// OpenAIChatCompletion. (The scraped diff duplicated the first line; a single
// Set literal is the valid form.)
private val legacyModels = Set("ada", "babbage", "curie", "davinci",
  "text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
  "code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
if (legacyModels.contains(getDeploymentName)) {
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
}
if (legacyModels.contains(getDeploymentName)) {
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ case class OpenAIChatChoice(message: OpenAIMessage,
index: Long,
finish_reason: String)

case class OpenAIUsage(completion_tokens: Long, prompt_tokens: Long, total_tokens: Long)

// Typed representation of a chat-completions response payload. Field names
// mirror the service JSON. The scraped diff left both the old closing line
// and the new fields, duplicating `choices`; this is the intended
// post-commit definition with the two new optional fields.
case class ChatCompletionResponse(id: String,
                                  `object`: String,
                                  created: String,
                                  model: String,
                                  choices: Seq[OpenAIChatChoice],
                                  // Optional: not present on older API versions,
                                  // hence Option rather than a bare value.
                                  system_fingerprint: Option[String],
                                  usage: Option[OpenAIUsage])

object ChatCompletionResponse extends SparkBindings[ChatCompletionResponse]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.scalactic.Equality

Expand Down Expand Up @@ -35,6 +35,16 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
(null, "none") //scalastyle:ignore null
).toDF("text", "category")

test("RAI Usage") {
  // A prompt expected to trip the Responsible AI content filter, so the
  // transformer should populate the error column rather than the output.
  val configured = prompt
    .setDeploymentName(deploymentNameGpt4)
    .setPromptTemplate("Tell me about a graphically disgusting movie in detail")
  val firstErrorRow = configured
    .transform(df)
    .select(prompt.getErrorCol)
    .collect()
    .head
  val errorValue = firstErrorRow.getAs[Row](0)
  assert(Option(errorValue).nonEmpty)
}

test("Basic Usage") {
val nonNullCount = prompt
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ trait HasErrorCol extends Params {
}

object ErrorUtils extends Serializable {

val ErrorSchema: StructType = new StructType()
.add("response", StringType, nullable = true)
.add("status", StatusLineData.schema, nullable = true)
Expand Down

0 comments on commit 478e8a2

Please sign in to comment.