From 7435ae889382e49cedc2f89d04fa7f312bccf62c Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Fri, 22 Nov 2024 14:39:54 -0800 Subject: [PATCH] fixing failing unit tests --- .../azure/synapse/ml/services/openai/OpenAIPrompt.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 7db574240c..fcc856009e 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -99,7 +99,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer private val localParamNames = Seq( "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages", - "systemPrompt") + "systemPrompt", "messagesCol") private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = { val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter @@ -191,11 +191,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer new OpenAICompletion() } else { - new OpenAIChatCompletion() + new OpenAIChatCompletion().setMessagesCol(getMessagesCol) } // apply all parameters extractParamMap().toSeq - .filter(p => completion.hasParam(p.param.name)) .filter(p => !localParamNames.contains(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value))