Skip to content

Commit

Permalink
Add Azure OpenAI structured outputs support (#98)
Browse files Browse the repository at this point in the history
As mentioned in
langchain4j/langchain4j#1982 (comment)

I added a complete test, maybe it's a bit too much, then I like to have
this as some kind of example on how to use the code.
  • Loading branch information
jdubois authored Dec 20, 2024
1 parent a4280ba commit cb8421a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ AzureOpenAiChatModel openAiChatModel(Properties properties) {
.presencePenalty(chatModelProperties.presencePenalty())
.frequencyPenalty(chatModelProperties.frequencyPenalty())
.seed(chatModelProperties.seed())
.strictJsonSchema(chatModelProperties.strictJsonSchema())
.timeout(Duration.ofSeconds(chatModelProperties.timeout() == null ? 0 : chatModelProperties.timeout()))
.maxRetries(chatModelProperties.maxRetries())
.proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration()))
.logRequestsAndResponses(chatModelProperties.logRequestsAndResponses() != null && chatModelProperties.logRequestsAndResponses())
.userAgentSuffix(chatModelProperties.userAgentSuffix())
.customHeaders(chatModelProperties.customHeaders());
.customHeaders(chatModelProperties.customHeaders())
.supportedCapabilities(chatModelProperties.supportedCapabilities());
if (chatModelProperties.nonAzureApiKey() != null) {
builder.nonAzureApiKey(chatModelProperties.nonAzureApiKey());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package dev.langchain4j.azure.openai.spring;

import dev.langchain4j.model.chat.Capability;

import java.util.List;
import java.util.Map;
import java.util.Set;

record ChatModelProperties(

Expand All @@ -18,12 +21,13 @@ record ChatModelProperties(
Double presencePenalty,
Double frequencyPenalty,
Long seed,
String responseFormat,
Boolean strictJsonSchema,
Integer timeout, // TODO use Duration instead
Integer maxRetries,
Boolean logRequestsAndResponses,
String userAgentSuffix,
Map<String, String> customHeaders,
String nonAzureApiKey
String nonAzureApiKey,
Set<Capability> supportedCapabilities
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
Expand All @@ -17,8 +23,12 @@
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.List;
import java.util.concurrent.CompletableFuture;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -53,6 +63,52 @@ void should_provide_chat_model(String deploymentName) {
});
}

class Person {

String name;
List<String> favouriteColors;
}

@ParameterizedTest(name = "Deployment name: {0}")
@CsvSource({
"gpt-4o-mini"
})
void should_provide_chat_model_with_json_schema(String deploymentName) {
contextRunner
.withPropertyValues(
"langchain4j.azure-open-ai.chat-model.api-key=" + AZURE_OPENAI_KEY,
"langchain4j.azure-open-ai.chat-model.endpoint=" + AZURE_OPENAI_ENDPOINT,
"langchain4j.azure-open-ai.chat-model.deployment-name=" + deploymentName,
"langchain4j.azure-open-ai.chat-model.strict-json-schema=true"
)
.run(context -> {

ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class);

ChatRequest chatRequest = ChatRequest.builder()
.messages(singletonList(userMessage("Julien likes blue, white and red")))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.rootElement(JsonObjectSchema.builder()
.addStringProperty("name")
.addProperty("favouriteColors", JsonArraySchema.builder()
.items(new JsonStringSchema())
.build())
.required("name", "favouriteColors")
.build())
.build())
.build())
.build();

assertThat(chatLanguageModel).isInstanceOf(AzureOpenAiChatModel.class);
AiMessage aiMessage = chatLanguageModel.chat(chatRequest).aiMessage();
assertThat(aiMessage.text()).contains("{\"name\":\"Julien\",\"favouriteColors\":[\"blue\",\"white\",\"red\"]}");
assertThat(context.getBean(AzureOpenAiChatModel.class)).isSameAs(chatLanguageModel);
});
}

@ParameterizedTest(name = "Deployment name: {0}")
@CsvSource({
"gpt-3.5-turbo"
Expand Down

0 comments on commit cb8421a

Please sign in to comment.