Azure OpenAI: add missing properties/parameters
LangChain4j committed Nov 7, 2024
1 parent 36ac085 commit 59df6a9
Showing 4 changed files with 52 additions and 33 deletions.
@@ -3,11 +3,7 @@
import com.azure.core.http.ProxyOptions;
import com.azure.core.util.Configuration;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
import dev.langchain4j.model.azure.AzureOpenAiImageModel;
import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
import dev.langchain4j.model.azure.AzureOpenAiTokenizer;
import dev.langchain4j.model.azure.*;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -37,18 +33,25 @@ AzureOpenAiChatModel openAiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(chatModelProperties.getEndpoint())
.serviceVersion(chatModelProperties.getServiceVersion())
.apiKey(chatModelProperties.getApiKey())
.deploymentName(chatModelProperties.getDeploymentName())
// TODO inject tokenizer?
.maxTokens(chatModelProperties.getMaxTokens())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxTokens(chatModelProperties.getMaxTokens())
.logitBias(chatModelProperties.getLogitBias())
.user(chatModelProperties.getUser())
.stop(chatModelProperties.getStop())
.presencePenalty(chatModelProperties.getPresencePenalty())
.frequencyPenalty(chatModelProperties.getFrequencyPenalty())
.seed(chatModelProperties.getSeed())
.timeout(Duration.ofSeconds(chatModelProperties.getTimeout() == null ? 0 : chatModelProperties.getTimeout()))
.maxRetries(chatModelProperties.getMaxRetries())
.proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration()))
.customHeaders(chatModelProperties.getCustomHeaders())
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses() != null && chatModelProperties.getLogRequestsAndResponses());
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses() != null && chatModelProperties.getLogRequestsAndResponses())
.userAgentSuffix(chatModelProperties.getUserAgentSuffix())
.customHeaders(chatModelProperties.getCustomHeaders());
if (chatModelProperties.getNonAzureApiKey() != null) {
builder.nonAzureApiKey(chatModelProperties.getNonAzureApiKey());
}
@@ -67,23 +70,29 @@ AzureOpenAiStreamingChatModel openAiStreamingChatModelByNonAzureApiKey(Propertie
return openAiStreamingChatModel(properties);
}


AzureOpenAiStreamingChatModel openAiStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
AzureOpenAiStreamingChatModel.Builder builder = AzureOpenAiStreamingChatModel.builder()
.endpoint(chatModelProperties.getEndpoint())
.serviceVersion(chatModelProperties.getServiceVersion())
.apiKey(chatModelProperties.getApiKey())
.deploymentName(chatModelProperties.getDeploymentName())
// TODO inject tokenizer?
.maxTokens(chatModelProperties.getMaxTokens())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.logitBias(chatModelProperties.getLogitBias())
.user(chatModelProperties.getUser())
.stop(chatModelProperties.getStop())
.maxTokens(chatModelProperties.getMaxTokens())
.presencePenalty(chatModelProperties.getPresencePenalty())
.frequencyPenalty(chatModelProperties.getFrequencyPenalty())
.seed(chatModelProperties.getSeed())
.timeout(Duration.ofSeconds(chatModelProperties.getTimeout() == null ? 0 : chatModelProperties.getTimeout()))
.maxRetries(chatModelProperties.getMaxRetries())
.proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration()))
.customHeaders(chatModelProperties.getCustomHeaders())
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses() != null && chatModelProperties.getLogRequestsAndResponses());
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses() != null && chatModelProperties.getLogRequestsAndResponses())
.userAgentSuffix(chatModelProperties.getUserAgentSuffix())
.customHeaders(chatModelProperties.getCustomHeaders());
if (chatModelProperties.getNonAzureApiKey() != null) {
builder.nonAzureApiKey(chatModelProperties.getNonAzureApiKey());
}
@@ -106,15 +115,17 @@ AzureOpenAiEmbeddingModel openAiEmbeddingModel(Properties properties, Tokenizer
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(embeddingModelProperties.getEndpoint())
.serviceVersion(embeddingModelProperties.getServiceVersion())
.apiKey(embeddingModelProperties.getApiKey())
.deploymentName(embeddingModelProperties.getDeploymentName())
.maxRetries(embeddingModelProperties.getMaxRetries())
.tokenizer(tokenizer)
.timeout(Duration.ofSeconds(embeddingModelProperties.getTimeout() == null ? 0 : embeddingModelProperties.getTimeout()))
.maxRetries(embeddingModelProperties.getMaxRetries())
.proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration()))
.customHeaders(embeddingModelProperties.getCustomHeaders())
.logRequestsAndResponses(embeddingModelProperties.getLogRequestsAndResponses() != null && embeddingModelProperties.getLogRequestsAndResponses());

.logRequestsAndResponses(embeddingModelProperties.getLogRequestsAndResponses() != null && embeddingModelProperties.getLogRequestsAndResponses())
.userAgentSuffix(embeddingModelProperties.getUserAgentSuffix())
.dimensions(embeddingModelProperties.getDimensions())
.customHeaders(embeddingModelProperties.getCustomHeaders());
if (embeddingModelProperties.getNonAzureApiKey() != null) {
builder.nonAzureApiKey(embeddingModelProperties.getNonAzureApiKey());
}
@@ -137,25 +148,26 @@ AzureOpenAiImageModel openAiImageModel(Properties properties) {
ImageModelProperties imageModelProperties = properties.getImageModel();
AzureOpenAiImageModel.Builder builder = AzureOpenAiImageModel.builder()
.endpoint(imageModelProperties.getEndpoint())
.serviceVersion(imageModelProperties.getServiceVersion())
.apiKey(imageModelProperties.getApiKey())
.deploymentName(imageModelProperties.getDeploymentName())
.size(imageModelProperties.getSize())
.quality(imageModelProperties.getQuality())
.style(imageModelProperties.getStyle())
.size(imageModelProperties.getSize())
.user(imageModelProperties.getUser())
.style(imageModelProperties.getStyle())
.responseFormat(imageModelProperties.getResponseFormat())
.timeout(imageModelProperties.getTimeout() == null ? null : Duration.ofSeconds(imageModelProperties.getTimeout()))
.maxRetries(imageModelProperties.getMaxRetries())
.proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration()))
.customHeaders(imageModelProperties.getCustomHeaders())
.logRequestsAndResponses(imageModelProperties.getLogRequestsAndResponses() != null && imageModelProperties.getLogRequestsAndResponses());
.logRequestsAndResponses(imageModelProperties.getLogRequestsAndResponses() != null && imageModelProperties.getLogRequestsAndResponses())
.userAgentSuffix(imageModelProperties.getUserAgentSuffix())
.customHeaders(imageModelProperties.getCustomHeaders());
if (imageModelProperties.getNonAzureApiKey() != null) {
builder.nonAzureApiKey(imageModelProperties.getNonAzureApiKey());
}
return builder.build();
}


@Bean
@ConditionalOnMissingBean
AzureOpenAiTokenizer azureOpenAiTokenizer() {
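For orientation, the chat-model wiring in the auto-configuration above corresponds roughly to the following hand-written builder usage. This is a sketch only: the endpoint, API key, deployment name and all parameter values are placeholders chosen for illustration, not values taken from the commit, and in the starter they would come from Spring configuration properties rather than being hard-coded.

import dev.langchain4j.model.azure.AzureOpenAiChatModel;

import java.time.Duration;
import java.util.List;
import java.util.Map;

public class AzureOpenAiChatModelSketch {

    public static void main(String[] args) {
        // Placeholder values only; in the starter these come from ChatModelProperties.
        AzureOpenAiChatModel model = AzureOpenAiChatModel.builder()
                .endpoint("https://my-resource.openai.azure.com/")  // placeholder endpoint
                .serviceVersion("2024-02-15-preview")               // newly exposed property; placeholder version
                .apiKey(System.getenv("AZURE_OPENAI_KEY"))          // placeholder key source
                .deploymentName("gpt-4o")                           // placeholder deployment
                .maxTokens(500)
                .temperature(0.2)
                .topP(0.95)
                .logitBias(Map.of("50256", -100))                   // newly exposed property; token id -> bias
                .user("demo-user")                                  // newly exposed property
                .stop(List.of("\n\n"))
                .presencePenalty(0.0)
                .frequencyPenalty(0.0)
                .seed(42L)
                .timeout(Duration.ofSeconds(60))
                .maxRetries(2)
                .logRequestsAndResponses(false)
                .userAgentSuffix("my-app")                          // newly exposed property
                .customHeaders(Map.of("x-custom-header", "value"))
                .build();
        // The built model can then be used like any other LangChain4j chat model.
    }
}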
@@ -11,20 +11,23 @@
class ChatModelProperties {

String endpoint;
String serviceVersion;
String apiKey;
String nonAzureApiKey;
String organizationId;
String deploymentName;
Integer maxTokens;
Double temperature;
Double topP;
Integer maxTokens;
Map<String, Integer> logitBias;
String user;
List<String> stop;
Double presencePenalty;
Double frequencyPenalty;
Long seed;
String responseFormat;
Integer seed;
List<String> stop;
Integer timeout;
Integer timeout; // TODO use Duration instead
Integer maxRetries;
Boolean logRequestsAndResponses;
String userAgentSuffix;
Map<String, String> customHeaders;
String nonAzureApiKey;
}
@@ -10,12 +10,14 @@
class EmbeddingModelProperties {

String endpoint;
String serviceVersion;
String apiKey;
String nonAzureApiKey;
String deploymentName;
Integer dimensions;
Integer timeout;
Integer timeout; // TODO use duration instead
Integer maxRetries;
Boolean logRequestsAndResponses;
String userAgentSuffix;
Integer dimensions;
Map<String, String> customHeaders;
String nonAzureApiKey;
}
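In the same spirit, the embedding-model properties above map onto the AzureOpenAiEmbeddingModel builder roughly as follows. Again a hedged sketch with placeholder values, not code from the commit; serviceVersion and userAgentSuffix are among the newly exposed properties.

import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;

import java.time.Duration;

public class AzureOpenAiEmbeddingModelSketch {

    public static void main(String[] args) {
        AzureOpenAiEmbeddingModel embeddingModel = AzureOpenAiEmbeddingModel.builder()
                .endpoint("https://my-resource.openai.azure.com/")  // placeholder endpoint
                .serviceVersion("2024-02-15-preview")               // newly exposed property; placeholder version
                .apiKey(System.getenv("AZURE_OPENAI_KEY"))          // placeholder key source
                .deploymentName("text-embedding-3-small")           // placeholder deployment
                .dimensions(512)                                    // optional output dimensionality, if the deployment supports it
                .timeout(Duration.ofSeconds(30))
                .maxRetries(2)
                .userAgentSuffix("my-app")                          // newly exposed property
                .build();
        // The model can then embed text as usual.
    }
}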
@@ -10,16 +10,18 @@
class ImageModelProperties {

String endpoint;
String serviceVersion;
String apiKey;
String nonAzureApiKey;
String deploymentName;
String size;
String quality;
String size;
String user;
String style;
String responseFormat;
String user;
Integer timeout;
Integer maxRetries;
Boolean logRequestsAndResponses;
String userAgentSuffix;
Map<String, String> customHeaders;
String nonAzureApiKey;
}
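Finally, the image-model properties map onto the AzureOpenAiImageModel builder roughly like this. The size, quality, style and responseFormat strings are illustrative DALL-E-style values, not values mandated by the commit; serviceVersion and userAgentSuffix are among the newly exposed properties.

import dev.langchain4j.model.azure.AzureOpenAiImageModel;

public class AzureOpenAiImageModelSketch {

    public static void main(String[] args) {
        AzureOpenAiImageModel imageModel = AzureOpenAiImageModel.builder()
                .endpoint("https://my-resource.openai.azure.com/")  // placeholder endpoint
                .serviceVersion("2024-02-15-preview")               // newly exposed property; placeholder version
                .apiKey(System.getenv("AZURE_OPENAI_KEY"))          // placeholder key source
                .deploymentName("dall-e-3")                         // placeholder deployment
                .size("1024x1024")                                  // example value
                .quality("standard")                                // example value
                .style("vivid")                                     // example value
                .user("demo-user")
                .responseFormat("url")                              // example value
                .userAgentSuffix("my-app")                          // newly exposed property
                .build();
        // The model can then be used to generate images as usual.
    }
}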
