Skip to content

Commit

Permalink
In this Changeset, I am:
Browse files Browse the repository at this point in the history
1. Ensuring that a new prompt instance is created for each test. This eliminates the chance of one test interfering with another.
2. Creating a test to validate that, for older model deployments, an OpenAICompletion instance is created in the OpenAIPrompt class.
3. Fixing the bug that caused the creation of the OpenAICompletion object to fail.
4. Cleaning up some code in the OpenAIPrompt class.
  • Loading branch information
FMasudMsft committed Nov 22, 2024
1 parent 08aab6a commit b4e4d44
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -294,7 +295,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
with HasOpenAISharedParams with OpenAIFabricSetting {
with HasOpenAISharedParams with OpenAIFabricSetting with HasCognitiveServiceInput {
setDefault(timeout -> 360.0)

private def usingDefaultOpenAIEndpoint(): Boolean = {
Expand All @@ -307,4 +308,6 @@ abstract class OpenAIServicesBase(override val uid: String) extends CognitiveSer
}
super.getInternalTransformer(schema)
}

/**
 * Accessor, visible only inside the `openai` package, that returns this service's
 * `prepareEntity` function.
 *
 * @return the function mapping a [[Row]] to an optional HTTP entity, exactly as
 *         produced by `prepareEntity`
 */
private[openai] def prepareEntityAIService: Row => Option[AbstractHttpEntity] = {
  prepareEntity
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
df.map({ row =>
val originalOutput = Option(row.getAs[Row](outputCol))
.map({ row => openAIResultFromRow(row).choices.head })
val isFiltered = originalOutput
.map(output => Option(output.message.content).isEmpty)
.getOrElse(false)
val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty)

if (isFiltered) {
val updatedRowSeq = row.toSeq.updated(
Expand Down Expand Up @@ -183,7 +181,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {
/**
* This method is made available in the `openai` package for testing purposes.
*/
private[openai] def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
if (legacyModels.contains(getDeploymentName)) {
Expand All @@ -194,29 +195,22 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
// apply all parameters
extractParamMap().toSeq
.filter(p => completion.hasParam(p.param.name))
.filter(p => !localParamNames.contains(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

completion
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion.prepareEntity(r)
case completion: OpenAICompletion =>
completion.prepareEntity(r)
}
}
override protected def prepareEntity: Row => Option[AbstractHttpEntity] = openAICompletion.prepareEntityAIService

private def getParser: OutputParser = {
val opts = getPostProcessingOptions

getPostProcessing.toLowerCase match {
case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ","))
case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty)
case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt)
case "json" => new JsonParser(opts("jsonSchema"), Map.empty)
case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt)
case "" => new PassThroughParser()
case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,42 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
super.beforeAll()
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)
/**
 * Builds a brand-new [[OpenAIPrompt]] on every call, configured with the suite's
 * API key and service name, the given deployment, the "outParsed" output column
 * and a temperature of 0.
 *
 * @param deploymentName deployment the prompt should target
 */
private def createPromptInstance(deploymentName: String) = {
  // Connection settings first ...
  val connected = new OpenAIPrompt()
    .setSubscriptionKey(openAIAPIKey)
    .setDeploymentName(deploymentName)
    .setCustomServiceName(openAIServiceName)

  // ... then the output/sampling settings, applied in the same order as before.
  connected
    .setOutputCol("outParsed")
    .setTemperature(0)
}


test("Validating OpenAICompletion and OpenAIChatCompletion for correctModels") {
  // Deployment names that must resolve to an OpenAICompletion service.
  val completionModels: Seq[String] = Seq(
    "ada", "babbage", "curie", "davinci",
    "text-ada-001", "text-babbage-001", "text-curie-001",
    "text-davinci-002", "text-davinci-003",
    "code-cushman-001", "code-davinci-002")

  // Deployment names that must resolve to an OpenAIChatCompletion service.
  val chatModels: Seq[String] = Seq("gpt-35-turbo", "gpt-4-turbo", "gpt-4o")

  // Table of (deployment names, expected service class), checked in declaration order.
  val expectations: Seq[(Seq[String], Class[_])] = Seq(
    (completionModels, classOf[OpenAICompletion]),
    (chatModels, classOf[OpenAIChatCompletion]))

  expectations.foreach { case (models, expectedType) =>
    models.foreach { model =>
      // A fresh prompt per model keeps each check independent of the others.
      assert(createPromptInstance(model).openAICompletion.getClass == expectedType)
    }
  }
}

lazy val df: DataFrame = Seq(
("apple", "fruits"),
Expand All @@ -36,8 +66,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
).toDF("text", "category")

test("RAI Usage") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
val result = prompt
.setDeploymentName(deploymentNameGpt4)
.setPromptTemplate("Tell me about a graphically disgusting movie in detail")
.transform(df)
.select(prompt.getErrorCol)
Expand All @@ -46,6 +76,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
val nonNullCount = prompt
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -58,6 +89,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -72,14 +104,9 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
val nonNullCount = promptGpt4
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -92,6 +119,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -107,6 +135,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Setting and Keeping Messages Col - Gpt 4") {
val promptGpt4: OpenAIPrompt = createPromptInstance(deploymentNameGpt4)
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
.setPromptTemplate(
Expand Down Expand Up @@ -154,6 +183,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

override def testObjects(): Seq[TestObject[OpenAIPrompt]] = {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
val testPrompt = prompt
.setPromptTemplate("{text} rhymes with ")

Expand Down

0 comments on commit b4e4d44

Please sign in to comment.