Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding code for sending parameter "response_format" as request payload #2317

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
maxTokens,
temperature,
topP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
Expand All @@ -18,8 +19,84 @@ import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

// Enumeration of the response formats that can be requested from the service.
// Each value carries a `name` (also passed to the underlying Enumeration `Val`) and a
// `prompt` string describing the required output format.
object OpenAIResponseFormat extends Enumeration {
// A single response format: `name` is the identifier registered with the Enumeration,
// `prompt` is the instruction text associated with that format.
case class ResponseFormat(name: String, prompt: String) extends super.Val(name)
// Plain-text output.
val TEXT: ResponseFormat = ResponseFormat("text", "Output must be in text format")
// JSON output; note the identifier is "json_object", not "json".
val JSON: ResponseFormat = ResponseFormat("json_object", "Output must be in JSON format")
}


/**
 * Extends [[HasOpenAITextParams]] with the `responseFormat` service parameter, which is
 * serialized into the request payload under the key `response_format`.
 *
 * The parameter value is a map holding a `"type"` entry whose value must be one of the
 * names declared on [[OpenAIResponseFormat]] (`"text"` or `"json_object"`).
 */
trait HasOpenAITextParamsExtended extends HasOpenAITextParams {
  val responseFormat: ServiceParam[Map[String, String]] = new ServiceParam[Map[String, String]](
    this,
    "responseFormat",
    "Response format for the completion. Can be 'json_object' or 'text'.",
    isRequired = false) {
    // The wire name differs from the param name: the payload key uses snake_case.
    override val payloadName: String = "response_format"
  }

  // Names of every format declared on OpenAIResponseFormat, sorted so that error messages
  // list them in a stable order.
  private def supportedResponseFormats: Seq[String] =
    OpenAIResponseFormat.values.toSeq
      .collect { case format: OpenAIResponseFormat.ResponseFormat => format.name }
      .sorted

  /** Returns the scalar value of `responseFormat`. */
  def getResponseFormat: Map[String, String] = getScalarParam(responseFormat)

  /**
   * Sets the response format from a raw payload map.
   *
   * @param value map that must contain a `"type"` entry naming a supported format
   * @throws IllegalArgumentException if `"type"` is missing or is not a supported format
   */
  def setResponseFormat(value: Map[String, String]): this.type = {
    // `get` rather than `apply`: a missing "type" key must be reported as an invalid
    // argument, not as a NoSuchElementException.
    val hasSupportedType = value.get("type").exists(t => supportedResponseFormats.contains(t))
    if (!hasSupportedType) {
      throw new IllegalArgumentException("Response format must be 'text' or 'json_object'")
    }
    setScalarParam(responseFormat, value)
  }

  /**
   * Sets the response format from its name. Matching is case-insensitive and `"json"` is
   * accepted as shorthand for `"json_object"`. An empty string leaves the parameter untouched.
   *
   * @param value name of the format
   * @throws IllegalArgumentException if the normalized name is not a supported format
   */
  def setResponseFormat(value: String): this.type = {
    if (value.isEmpty) {
      this
    } else {
      val normalizedValue = value.toLowerCase match {
        case "json" => "json_object"
        case other => other
      }
      // Validate the normalized value against the OpenAIResponseFormat enum
      if (!supportedResponseFormats.contains(normalizedValue)) {
        throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
          "Currently supported formats are " + supportedResponseFormats.mkString(", "))
      }

      setScalarParam(responseFormat, Map("type" -> normalizedValue))
    }
  }

  /**
   * Sets the response format from an [[OpenAIResponseFormat]] value by delegating to the
   * string overload with the value's name.
   */
  def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
    // TODO: this method should throw an exception if the openAiCompletion is not a ChatCompletion
    this.setResponseFormat(value.name)
  }

  /** Returns the name of the column bound to `responseFormat`. */
  def getResponseFormatCol: String = getVectorParam(responseFormat)

  /** Binds `responseFormat` to the column named `value`. */
  def setResponseFormatCol(value: String): this.type = setVectorParam(responseFormat, value)

  // Recreating the sharedTextParams sequence to include additional parameter responseFormat
  override private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
    maxTokens,
    temperature,
    topP,
    user,
    n,
    echo,
    stop,
    cacheLevel,
    presencePenalty,
    frequencyPenalty,
    bestOf,
    logProbs,
    responseFormat // Additional parameter
  )
}

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
with HasOpenAITextParamsExtended with HasMessagesInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down Expand Up @@ -55,11 +132,20 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
override def responseDataType: DataType = ChatCompletionResponse.schema

private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
var mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Copy link
Collaborator

@mhamilton723 mhamilton723 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please no vars in code unless required. In scala any kind of mutability is highly discouraged. Here in this case you can return either the original or the original plus your addition in the if statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made mappedMessages val again

Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
).toMap.filter(_._2.isDefined).mapValues(_.get)
}

// if optionalParams contains the "response_format" key and its value has "type" set to "json_object",
// then we need to add a message instructing OpenAI to return the response in JSON format
if (optionalParams.get("response_format")
.exists(_.asInstanceOf[Map[String, String]]("type")
.contentEquals("json_object"))) {
mappedMessages :+= Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt)
}

val fullPayload = optionalParams.updated("messages", mappedMessages)
new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,37 @@ class OpenAIPrompt(override val uid: String) extends Transformer

/** Returns the current value of the `postProcessingOptions` param. */
def getPostProcessingOptions: Map[String, String] = $(postProcessingOptions)

def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value)
/**
 * Stores the post-processing options and derives the matching `postProcessing` mode from
 * the keys present in `value`:
 *  - `"delimiter"`  -> `"csv"`
 *  - `"jsonSchema"` -> `"json"`
 *  - `"regex"`      -> `"regex"` (also requires `"regexGroup"`)
 *
 * Keys are checked in the order listed above; the first one found decides the mode.
 * If `postProcessing` is already set it must agree with the derived mode.
 *
 * @param value option map for the post-processor
 * @throws IllegalArgumentException if none of the recognized keys is present, if
 *         `"regex"` is given without `"regexGroup"`, or if `postProcessing` is already
 *         set to a different mode
 */
def setPostProcessingOptions(value: Map[String, String]): this.type = {
  // Pin `postProcessing` to `mode`, or verify that an already-set value agrees with it.
  def pinPostProcessing(mode: String): Unit = {
    if (!isSet(postProcessing)) {
      set(postProcessing, mode)
    } else {
      require(getPostProcessing == mode, s"postProcessing must be '$mode'")
    }
  }

  if (value.contains("delimiter")) {
    pinPostProcessing("csv")
  } else if (value.contains("jsonSchema")) {
    pinPostProcessing("json")
  } else if (value.contains("regex")) {
    // A regex is unusable without knowing which capture group to extract.
    require(value.contains("regexGroup"), "regexGroup must be specified with regex")
    pinPostProcessing("regex")
  } else {
    throw new IllegalArgumentException("Invalid post processing options")
  }

  // Only reached when the options passed validation above.
  set(postProcessingOptions, value)
}

/**
 * Java-friendly overload: converts the `HashMap` to an immutable Scala map and delegates
 * to the `Map[String, String]` overload, so that Java callers get the same validation and
 * `postProcessing` inference as Scala callers.
 *
 * @param v option map for the post-processor
 * @throws IllegalArgumentException under the same conditions as the Scala overload
 */
def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type =
  setPostProcessingOptions(v.asScala.toMap)
Expand All @@ -79,10 +109,43 @@ class OpenAIPrompt(override val uid: String) extends Transformer

/** Sets the `systemPrompt` param to `value` and returns this instance for chaining. */
def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

// String param holding the requested response format; values are written by the
// `setResponseFormat` overloads.
val responseFormat = new Param[String](
this, "responseFormat", "The response format from the OpenAI API.")

/** Returns the current value of the `responseFormat` param. */
def getResponseFormat: String = $(responseFormat)

/**
 * Sets the response format from its name. The name is lower-cased before validation and
 * `"json"` is treated as shorthand for `"json_object"`. An empty string is a no-op.
 *
 * @param value name of the format
 * @throws IllegalArgumentException if the normalized name is not declared on
 *         `OpenAIResponseFormat`
 */
def setResponseFormat(value: String): this.type = {
  // Names of every format declared on the enum, sorted for a stable error message.
  def allowedFormats: Seq[String] = OpenAIResponseFormat.values.toSeq
    .collect { case fmt: OpenAIResponseFormat.ResponseFormat => fmt.name }
    .sorted

  if (value.nonEmpty) {
    val lowered = value.toLowerCase
    val canonical = if (lowered == "json") "json_object" else lowered

    if (!allowedFormats.contains(canonical)) {
      throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
        "Currently supported formats are " + allowedFormats.mkString(", "))
    }

    set(responseFormat, canonical)
  } else {
    // Empty input leaves the param as it was.
    this
  }
}

/**
 * Sets the response format from an `OpenAIResponseFormat` value by delegating to the
 * string overload with the value's `name`.
 */
def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
this.setResponseFormat(value.name)
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code seems duplicated, can it be abstracted and shared?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am now using shared code.


// Built-in system prompt text.
// NOTE(review): presumably registered as the default for `systemPrompt` — the
// registration is not visible in this excerpt; confirm.
private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
"Follow their instructions carefully and be brief if they don't say otherwise."

setDefault(
responseFormat -> "",
postProcessing -> "",
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
Expand All @@ -99,16 +162,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")
"systemPrompt", "responseFormat")

private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
df.map({ row =>
val originalOutput = Option(row.getAs[Row](outputCol))
.map({ row => openAIResultFromRow(row).choices.head })
val isFiltered = originalOutput
.map(output => Option(output.message.content).isEmpty)
.getOrElse(false)
val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty)

if (isFiltered) {
val updatedRowSeq = row.toSeq.updated(
Expand Down Expand Up @@ -138,6 +199,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
if (isSet(responseFormat)) {
chatCompletion.setResponseFormat(getResponseFormat)
}
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)
Expand All @@ -158,6 +222,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for OpenAICompletion")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)
Expand Down Expand Up @@ -215,8 +282,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer

getPostProcessing.toLowerCase match {
case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ","))
case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty)
case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt)
case "json" => new JsonParser(opts("jsonSchema"), Map.empty)
case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt)
case "" => new PassThroughParser()
case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, Transformer
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
import org.scalatest.matchers.must.Matchers.{an, be}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] with OpenAIAPIKey with Flaky {

Expand Down Expand Up @@ -174,6 +176,45 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
testCompletion(customEndpointCompletion, goodDf)
}

// Verifies that the string overload of setResponseFormat normalizes case and the "json"
// shorthand before storing the value as Map("type" -> <normalized name>).
// NOTE(review): `completion` appears to be a fixture shared with other tests — confirm that
// leaving responseFormat set on it does not affect them.
test("setResponseFormat should set the response format correctly") {
  // (raw input, expected normalized "type" value), checked in this order.
  val cases = Seq(
    "text" -> "text",
    "tExT" -> "text",
    "json" -> "json_object",
    "JSON" -> "json_object",
    "json_object" -> "json_object",
    "Json_ObjeCt" -> "json_object"
  )

  cases.foreach { case (rawInput, expectedType) =>
    completion.setResponseFormat(rawInput)
    completion.getResponseFormat shouldEqual Map("type" -> expectedType)
  }
}

// An unsupported format name must be rejected with IllegalArgumentException.
test("setResponseFormat should throw an exception for invalid response format") {
an[IllegalArgumentException] should be thrownBy {
completion.setResponseFormat("invalid_format")
}
}

// Verifies that the enum overload of setResponseFormat stores Map("type" -> <enum name>).
test("setResponseFormat with ResponseFormat should set the response format correctly") {
  // (enum value, expected "type" value), checked in this order.
  val cases = Seq(
    OpenAIResponseFormat.TEXT -> "text",
    OpenAIResponseFormat.JSON -> "json_object"
  )

  cases.foreach { case (format, expectedType) =>
    completion.setResponseFormat(format)
    completion.getResponseFormat shouldEqual Map("type" -> expectedType)
  }
}

// Round-trip check: the column name given to setResponseFormatCol must be returned
// unchanged by getResponseFormatCol.
test("setResponseFormatCol should set the response format column correctly") {
completion.setResponseFormatCol("response_format_col")
completion.getResponseFormatCol shouldEqual "response_format_col"
}

def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = {
val fromRow = ChatCompletionResponse.makeFromRowConverter
completion.transform(df).collect().foreach(r =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ trait OpenAIAPIKey {
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
lazy val deploymentNameGpt4: String = "gpt-4"
lazy val deploymentNameGpt4o: String = "gpt-4o"
lazy val modelNameGpt4: String = "gpt-4"
}

Expand Down
Loading
Loading