Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixing OpenAIPrompt class to be able to work with older models like text-davinci-003 #2319

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -302,7 +303,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
with HasOpenAISharedParams with OpenAIFabricSetting {
with HasOpenAISharedParams with OpenAIFabricSetting with HasCognitiveServiceInput {
setDefault(timeout -> 360.0)

private def usingDefaultOpenAIEndpoint(): Boolean = {
Expand All @@ -315,4 +316,6 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer
}
super.getInternalTransformer(schema)
}

private[openai] def prepareEntityAIService: Row => Option[AbstractHttpEntity] = this.prepareEntity
FarrukhMasud marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
df.map({ row =>
val originalOutput = Option(row.getAs[Row](outputCol))
.map({ row => openAIResultFromRow(row).choices.head })
val isFiltered = originalOutput
FarrukhMasud marked this conversation as resolved.
Show resolved Hide resolved
.map(output => Option(output.message.content).isEmpty)
.getOrElse(false)
val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty)

if (isFiltered) {
val updatedRowSeq = row.toSeq.updated(
Expand Down Expand Up @@ -208,14 +206,17 @@ class OpenAIPrompt(override val uid: String) extends Transformer
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {
/**
* This method is made available in the `openai` package for testing purposes.
*/
private[openai] def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
if (legacyModels.contains(getDeploymentName)) {
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
new OpenAIChatCompletion().setMessagesCol(getMessagesCol)
}
// apply all parameters
extractParamMap().toSeq
Expand All @@ -225,23 +226,15 @@ class OpenAIPrompt(override val uid: String) extends Transformer
completion
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion.prepareEntity(r)
case completion: OpenAICompletion =>
completion.prepareEntity(r)
}
}
override protected def prepareEntity: Row => Option[AbstractHttpEntity] = openAICompletion.prepareEntityAIService

private def getParser: OutputParser = {
val opts = getPostProcessingOptions

getPostProcessing.toLowerCase match {
case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ","))
case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty)
case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt)
case "json" => new JsonParser(opts("jsonSchema"), Map.empty)
case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt)
case "" => new PassThroughParser()
case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,42 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
super.beforeAll()
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)
private def createPromptInstance(deploymentName: String) = {
new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)
}


test("Validating OpenAICompletion and OpenAIChatCompletion for correctModels") {
FarrukhMasud marked this conversation as resolved.
Show resolved Hide resolved
val generation1Models: Seq[String] = Seq("ada",
"babbage",
"curie",
"davinci",
"text-ada-001",
"text-babbage-001",
"text-curie-001",
"text-davinci-002",
"text-davinci-003",
"code-cushman-001",
"code-davinci-002")

val generation2Models: Seq[String] = Seq("gpt-35-turbo",
"gpt-4-turbo",
"gpt-4o")

def validateModelType(models: Seq[String], expectedType: Class[_]): Unit = {
models.foreach { model =>
assert(createPromptInstance(model).openAICompletion.getClass == expectedType)
}
}

validateModelType(generation1Models, classOf[OpenAICompletion])
validateModelType(generation2Models, classOf[OpenAIChatCompletion])
}

lazy val df: DataFrame = Seq(
("apple", "fruits"),
Expand All @@ -36,8 +66,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
).toDF("text", "category")

test("RAI Usage") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
val result = prompt
.setDeploymentName(deploymentNameGpt4)
.setPromptTemplate("Tell me about a graphically disgusting movie in detail")
.transform(df)
.select(prompt.getErrorCol)
Expand All @@ -46,6 +76,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
val nonNullCount = prompt
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -58,6 +89,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -72,14 +104,9 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
val nonNullCount = promptGpt4
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -92,6 +119,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand Down Expand Up @@ -121,6 +149,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Setting and Keeping Messages Col - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
.setPromptTemplate(
Expand Down Expand Up @@ -208,6 +237,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

override def testObjects(): Seq[TestObject[OpenAIPrompt]] = {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
val testPrompt = prompt
.setPromptTemplate("{text} rhymes with ")

Expand Down
Loading