diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
index 7db574240c..fcc856009e 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
@@ -99,7 +99,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
 
   private val localParamNames = Seq(
     "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
-    "systemPrompt")
+    "systemPrompt", "messagesCol")
 
   private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
     val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
@@ -191,11 +191,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
       new OpenAICompletion()
     } else {
-      new OpenAIChatCompletion()
+      new OpenAIChatCompletion().setMessagesCol(getMessagesCol)
     }
     // apply all parameters
     extractParamMap().toSeq
-      .filter(p => completion.hasParam(p.param.name))
      .filter(p => !localParamNames.contains(p.param.name))
      .foreach(p => completion.set(completion.getParam(p.param.name), p.value))
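
Not part of the patch itself: a minimal usage sketch of the code path this change touches, assuming the documented OpenAIPrompt setters; openAIKey, openAIServiceName, the deployment name, and the input DataFrame are illustrative placeholders.

import com.microsoft.azure.synapse.ml.services.openai.OpenAIPrompt
import spark.implicits._  // assumes an active SparkSession named `spark`

// Illustrative input column referenced by the prompt template below.
val df = Seq("apple", "banana").toDF("fruit")

val prompt = new OpenAIPrompt()
  .setSubscriptionKey(openAIKey)            // placeholder credential
  .setCustomServiceName(openAIServiceName)  // placeholder Azure OpenAI resource name
  .setDeploymentName("gpt-4o")              // any non-legacy (chat) deployment
  .setPromptTemplate("Name one color commonly associated with a {fruit}.")
  .setOutputCol("color")

// With this patch, the OpenAIChatCompletion created internally is handed the prompt's
// messages column via setMessagesCol(getMessagesCol), and "messagesCol" is excluded from
// the generic parameter copy so that value is not clobbered.
prompt.transform(df).show()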