Skip to content

Commit

Permalink
fixing failing unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FMasudMsft committed Nov 22, 2024
1 parent b4e4d44 commit 7435ae8
Showing 1 changed file with 2 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")
"systemPrompt", "messagesCol")

private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
Expand Down Expand Up @@ -191,11 +191,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
new OpenAIChatCompletion().setMessagesCol(getMessagesCol)
}
// apply all parameters
extractParamMap().toSeq
.filter(p => completion.hasParam(p.param.name))
.filter(p => !localParamNames.contains(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

Expand Down

0 comments on commit 7435ae8

Please sign in to comment.