diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 8d40f49ee6..5435fcdbf3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -100,11 +100,12 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion } trait HasOpenAITextParams extends HasOpenAISharedParams { - val maxTokens: ServiceParam[Int] = new ServiceParam[Int]( this, "maxTokens", "The maximum number of tokens to generate. Has minimum of 0.", - isRequired = false) + isRequired = false){ + override val payloadName: String = "max_tokens" + } def getMaxTokens: Int = getScalarParam(maxTokens) @@ -149,7 +150,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { " So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." + " We generally recommend using this or `temperature` but not both." + " Minimum of 0 and maximum of 1 allowed.", - isRequired = false) + isRequired = false) { + override val payloadName: String = "top_p" + } def getTopP: Double = getScalarParam(topP) @@ -178,7 +181,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { " So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." + " If `logprobs` is 0, only the chosen tokens will have logprobs returned." + " Minimum of 0 and maximum of 100 allowed.", - isRequired = false) + isRequired = false) { + override val payloadName: String = "logprobs" + } def getLogProbs: Int = getScalarParam(logProbs) @@ -204,7 +209,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { val cacheLevel: ServiceParam[Int] = new ServiceParam[Int]( this, "cacheLevel", "can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache", - isRequired = false) + isRequired = false){ + override val payloadName: String = "cache_level" + } def getCacheLevel: Int = getScalarParam(cacheLevel) @@ -218,7 +225,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { this, "presencePenalty", "How much to penalize new tokens based on their existing frequency in the text so far." + " Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.", - isRequired = false) + isRequired = false){ + override val payloadName: String = "presence_penalty" + } def getPresencePenalty: Double = getScalarParam(presencePenalty) @@ -232,7 +241,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { this, "frequencyPenalty", "How much to penalize new tokens based on whether they appear in the text so far." + " Increases the likelihood of the model to talk about new topics.", - isRequired = false) + isRequired = false){ + override val payloadName: String = "frequency_penalty" + } def getFrequencyPenalty: Double = getScalarParam(frequencyPenalty) @@ -246,7 +257,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { this, "bestOf", "How many generations to create server side, and display only the best." + " Will not stream intermediate progress if best_of > 1. Has maximum value of 128.", - isRequired = false) + isRequired = false){ + override val payloadName: String = "best_of" + } def getBestOf: Int = getScalarParam(bestOf) @@ -256,24 +269,27 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v) + // list of shared text parameters. In method getOptionalParams, we will iterate over these parameters + // to compute the optional parameters. Since this list never changes, we can create it once and reuse it. + private val sharedTextParams = Seq( + maxTokens, + temperature, + topP, + user, + n, + echo, + stop, + cacheLevel, + presencePenalty, + frequencyPenalty, + bestOf, + logProbs + ) + private[ml] def getOptionalParams(r: Row): Map[String, Any] = { - Seq( - maxTokens, - temperature, - topP, - user, - n, - echo, - stop, - cacheLevel, - presencePenalty, - frequencyPenalty, - bestOf - ).flatMap(param => - getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v)) - ).++(Seq( - getValueOpt(r, logProbs).map(v => ("logprobs", v)) - ).flatten).toMap + sharedTextParams.flatMap { param => + getValueOpt(r, param).map { value => param.payloadName -> value } + }.toMap } }