From 3041558f768c30939347d710a8fab4c6af35ab3d Mon Sep 17 00:00:00 2001 From: bidek <1751659+bidek@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:51:44 +0100 Subject: [PATCH 1/4] Ollama Spring Boot Starter - Supported capabilities in application.properties (#96) ## Issue https://github.com/langchain4j/langchain4j/pull/2250 ## Change Add support for passing supported capabilities through application.properties ## General checklist - [x] There are no breaking changes - [x] I have added unit and/or integration tests for my change - [ ] The tests cover both positive and negative cases - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) --- .../langchain4j/ollama/spring/AutoConfig.java | 2 ++ .../ollama/spring/ChatModelProperties.java | 3 +++ .../ollama/spring/AutoConfigIT.java | 20 +++++++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java index 65245793..02e776c0 100644 --- a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java +++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java @@ -27,6 +27,7 @@ OllamaChatModel ollamaChatModel(Properties properties) { .numPredict(chatModelProperties.getNumPredict()) .stop(chatModelProperties.getStop()) .format(chatModelProperties.getFormat()) + .supportedCapabilities(chatModelProperties.getSupportedCapabilities()) .timeout(chatModelProperties.getTimeout()) .maxRetries(chatModelProperties.getMaxRetries()) .customHeaders(chatModelProperties.getCustomHeaders()) @@ -50,6 +51,7 @@ OllamaStreamingChatModel ollamaStreamingChatModel(Properties properties) { .numPredict(chatModelProperties.getNumPredict()) .stop(chatModelProperties.getStop()) .format(chatModelProperties.getFormat()) + .supportedCapabilities(chatModelProperties.getSupportedCapabilities()) .timeout(chatModelProperties.getTimeout()) .customHeaders(chatModelProperties.getCustomHeaders()) .logRequests(chatModelProperties.getLogRequests()) diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java index 0c75cdec..8fcc970c 100644 --- a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java +++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java @@ -1,11 +1,13 @@ package dev.langchain4j.ollama.spring; +import dev.langchain4j.model.chat.Capability; import lombok.Getter; import lombok.Setter; import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.Set; @Getter @Setter @@ -21,6 +23,7 @@ class ChatModelProperties { Integer numPredict; List stop; String format; + Set supportedCapabilities; Duration timeout; Integer maxRetries; Map customHeaders; diff --git a/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java b/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java index fced9a72..3e7451f0 100644 --- a/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java +++ b/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java @@ -18,6 +18,7 @@ import java.util.concurrent.CompletableFuture; +import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -57,6 +58,25 @@ void should_provide_chat_model() { }); } + @Test + void should_provide_chat_model_with_supported_capabilities() { + contextRunner + .withPropertyValues( + "langchain4j.ollama.chat-model.base-url=" + baseUrl(), + "langchain4j.ollama.chat-model.model-name=" + MODEL_NAME, + "langchain4j.ollama.chat-model.supportedCapabilities=RESPONSE_FORMAT_JSON_SCHEMA" + ) + .run(context -> { + + ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); + assertThat(chatLanguageModel).isInstanceOf(OllamaChatModel.class); + assertThat(chatLanguageModel.supportedCapabilities()).contains(RESPONSE_FORMAT_JSON_SCHEMA); + + assertThat(context.getBean(OllamaChatModel.class)).isSameAs(chatLanguageModel); + }); + } + + @Test void should_provide_streaming_chat_model() { contextRunner From 3bdf8f235b1609856c7b90efe5adb3f0a5ec9e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ji=C5=99=C3=AD=20Krokviak?= <26182195+Empatixx@users.noreply.github.com> Date: Fri, 20 Dec 2024 12:25:36 +0100 Subject: [PATCH 2/4] #2247 - autoconfig openai maxSegmentsPerBatch (#97) ## Issue Continues https://github.com/langchain4j/langchain4j/pull/2248 from Langchain4J, adding new parameter for EmbeddingModel from OpenAi ## Change I added new autoconfig field `maxSegmentsPerBatch` ## General checklist - [X] There are no breaking changes - [ ] I have added unit and/or integration tests for my change - [ ] The tests cover both positive and negative cases - [ ] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) --- .../src/main/java/dev/langchain4j/openai/spring/AutoConfig.java | 1 + .../dev/langchain4j/openai/spring/EmbeddingModelProperties.java | 1 + 2 files changed, 2 insertions(+) diff --git a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java index d0db55ac..3feef275 100644 --- a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java +++ b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java @@ -122,6 +122,7 @@ OpenAiEmbeddingModel openAiEmbeddingModel(Properties properties) { .organizationId(embeddingModelProperties.organizationId()) .modelName(embeddingModelProperties.modelName()) .dimensions(embeddingModelProperties.dimensions()) + .maxSegmentsPerBatch(embeddingModelProperties.maxSegmentsPerBatch()) .user(embeddingModelProperties.user()) .timeout(embeddingModelProperties.timeout()) .maxRetries(embeddingModelProperties.maxRetries()) diff --git a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/EmbeddingModelProperties.java b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/EmbeddingModelProperties.java index 381d0f34..d790effc 100644 --- a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/EmbeddingModelProperties.java +++ b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/EmbeddingModelProperties.java @@ -13,6 +13,7 @@ record EmbeddingModelProperties( String organizationId, String modelName, Integer dimensions, + Integer maxSegmentsPerBatch, String user, Duration timeout, Integer maxRetries, From a4280ba2bd8ecf0bd01a21a58fa0fdf86cbaa0cb Mon Sep 17 00:00:00 2001 From: Enigma Date: Fri, 20 Dec 2024 21:30:20 +0800 Subject: [PATCH 3/4] AiServiceRegisteredEvent (#89) Closes https://github.com/langchain4j/langchain4j/issues/2112 Publish a `AiServiceRegisteredEvent` Spring Event after registering the `AiService` bean in `AiServicesAutoConfig`. This event contains the `AiService` class and its corresponding tools description information. Once a user implements the even listener to listen for this event, they can receive the event during the Spring Boot startup phase and handle their business logic as needed. --- .../service/spring/AiServicesAutoConfig.java | 40 +++++++++++++++---- .../event/AiServiceRegisteredEvent.java | 28 +++++++++++++ .../AiServiceWithToolsApplication.java | 16 +++++++- .../withTools/AiServicesAutoConfigIT.java | 32 +++++++++++++++ .../listener/AbstractApplicationListener.java | 24 +++++++++++ .../AiServiceRegisteredEventListener.java | 8 ++++ 6 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 69f6fddf..629f3894 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -1,6 +1,8 @@ package dev.langchain4j.service.spring; import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; import dev.langchain4j.exception.IllegalConfigurationException; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; @@ -9,19 +11,21 @@ import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.ManagedList; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.Bean; import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; +import java.util.*; import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; import static dev.langchain4j.internal.Exceptions.illegalArgument; @@ -31,7 +35,16 @@ import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT; import static java.util.Arrays.asList; -public class AiServicesAutoConfig { +public class AiServicesAutoConfig implements ApplicationEventPublisherAware { + + private static final Logger log = LoggerFactory.getLogger(AiServicesAutoConfig.class); + + private ApplicationEventPublisher eventPublisher; + + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } @Bean BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { @@ -46,7 +59,8 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class); String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); - Set tools = new HashSet<>(); + Set toolBeanNames = new HashSet<>(); + List toolSpecifications = new ArrayList<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName(); @@ -56,7 +70,13 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { Class beanClass = Class.forName(beanClassName); for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { - tools.add(beanName); + toolBeanNames.add(beanName); + try { + toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod)); + } catch (Exception e) { + log.warn("Cannot convert %s.%s method annotated with @Tool into ToolSpecification" + .formatted(beanClass.getName(), beanMethod.getName()), e); + } } } } catch (Exception e) { @@ -148,7 +168,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { - propertyValues.add("tools", toManagedList(tools)); + propertyValues.add("tools", toManagedList(toolBeanNames)); } else { throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode()); } @@ -156,6 +176,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory; registry.removeBeanDefinition(aiService); registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition); + + if (eventPublisher != null) { + eventPublisher.publishEvent(new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications)); + } } }; } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java new file mode 100644 index 00000000..8d15516e --- /dev/null +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java @@ -0,0 +1,28 @@ +package dev.langchain4j.service.spring.event; + +import dev.langchain4j.agent.tool.ToolSpecification; +import org.springframework.context.ApplicationEvent; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; + +public class AiServiceRegisteredEvent extends ApplicationEvent { + + private final Class aiServiceClass; + private final List toolSpecifications; + + public AiServiceRegisteredEvent(Object source, Class aiServiceClass, List toolSpecifications) { + super(source); + this.aiServiceClass = aiServiceClass; + this.toolSpecifications = copyIfNotNull(toolSpecifications); + } + + public Class aiServiceClass() { + return aiServiceClass; + } + + public List toolSpecifications() { + return toolSpecifications; + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java index 1104d337..94e82885 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java @@ -1,12 +1,26 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.ApplicationListener; + +import java.util.List; @SpringBootApplication -class AiServiceWithToolsApplication { +class AiServiceWithToolsApplication implements ApplicationListener { public static void main(String[] args) { SpringApplication.run(AiServiceWithToolsApplication.class, args); } + + @Override + public void onApplicationEvent(AiServiceRegisteredEvent event) { + Class aiServiceClass = event.aiServiceClass(); + List toolSpecifications = event.toolSpecifications(); + for (int i = 0; i < toolSpecifications.size(); i++) { + System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i)); + } + } } diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java index 49bf50b5..3717cc10 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java @@ -1,11 +1,16 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.service.spring.AiServicesAutoConfig; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect; +import dev.langchain4j.service.spring.mode.automatic.withTools.listener.AiServiceRegisteredEventListener; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import java.util.List; + import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY; import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY; import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION; @@ -16,6 +21,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; class AiServicesAutoConfigIT { @@ -69,6 +75,32 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag }); } + @Test + void should_receive_ai_service_registered_event() { + contextRunner + .withUserConfiguration(AiServiceWithToolsApplication.class) + .run(context -> { + + // given + AiServiceRegisteredEventListener listener = context.getBean(AiServiceRegisteredEventListener.class); + + // then should receive AiServiceRegisteredEvent + assertTrue(listener.isEventReceived()); + assertEquals(1, listener.getReceivedEvents().size()); + + AiServiceRegisteredEvent event = listener.getReceivedEvents().stream().findFirst().orElse(null); + assertNotNull(event); + assertEquals(AiServiceWithTools.class, event.aiServiceClass()); + assertEquals(4, event.toolSpecifications().size()); + + List tools = event.toolSpecifications().stream().map(ToolSpecification::name).toList(); + assertTrue(tools.contains("getCurrentDate")); + assertTrue(tools.contains("getCurrentTime")); + assertTrue(tools.contains("getToolObserverPackageName")); + assertTrue(tools.contains("getToolObserverKey")); + }); + } + @Test void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() { contextRunner diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java new file mode 100644 index 00000000..cb41209a --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java @@ -0,0 +1,24 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.listener; + +import org.springframework.context.ApplicationEvent; +import org.springframework.context.ApplicationListener; + +import java.util.ArrayList; +import java.util.List; + +public class AbstractApplicationListener implements ApplicationListener { + private final List receivedEvents = new ArrayList<>(); + + @Override + public void onApplicationEvent(E event) { + receivedEvents.add(event); + } + + public List getReceivedEvents() { + return receivedEvents; + } + + public boolean isEventReceived() { + return !receivedEvents.isEmpty(); + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java new file mode 100644 index 00000000..35ef0c1e --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java @@ -0,0 +1,8 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.listener; + +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import org.springframework.stereotype.Component; + +@Component +public class AiServiceRegisteredEventListener extends AbstractApplicationListener { +} From cb8421aa9e7f0d552cfbefa5b4f6ad42c370b9a4 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Fri, 20 Dec 2024 14:38:59 +0100 Subject: [PATCH 4/4] Add Azure OpenAI structured outputs support (#98) As mentioned in https://github.com/langchain4j/langchain4j/pull/1982#issuecomment-2553095472 I added a complete test, maybe it's a bit too much, then I like to have this as some kind of example on how to use the code. --- .../azure/openai/spring/AutoConfig.java | 4 +- .../openai/spring/ChatModelProperties.java | 8 ++- .../azure/openai/spring/AutoConfigIT.java | 56 +++++++++++++++++++ 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java index 6d96d798..244f90b4 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java @@ -46,12 +46,14 @@ AzureOpenAiChatModel openAiChatModel(Properties properties) { .presencePenalty(chatModelProperties.presencePenalty()) .frequencyPenalty(chatModelProperties.frequencyPenalty()) .seed(chatModelProperties.seed()) + .strictJsonSchema(chatModelProperties.strictJsonSchema()) .timeout(Duration.ofSeconds(chatModelProperties.timeout() == null ? 0 : chatModelProperties.timeout())) .maxRetries(chatModelProperties.maxRetries()) .proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration())) .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses() != null && chatModelProperties.logRequestsAndResponses()) .userAgentSuffix(chatModelProperties.userAgentSuffix()) - .customHeaders(chatModelProperties.customHeaders()); + .customHeaders(chatModelProperties.customHeaders()) + .supportedCapabilities(chatModelProperties.supportedCapabilities()); if (chatModelProperties.nonAzureApiKey() != null) { builder.nonAzureApiKey(chatModelProperties.nonAzureApiKey()); } diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java index 5835406e..2c957428 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java @@ -1,7 +1,10 @@ package dev.langchain4j.azure.openai.spring; +import dev.langchain4j.model.chat.Capability; + import java.util.List; import java.util.Map; +import java.util.Set; record ChatModelProperties( @@ -18,12 +21,13 @@ record ChatModelProperties( Double presencePenalty, Double frequencyPenalty, Long seed, - String responseFormat, + Boolean strictJsonSchema, Integer timeout, // TODO use Duration instead Integer maxRetries, Boolean logRequestsAndResponses, String userAgentSuffix, Map customHeaders, - String nonAzureApiKey + String nonAzureApiKey, + Set supportedCapabilities ) { } \ No newline at end of file diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java b/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java index bd06a23c..5547a8fd 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java @@ -8,6 +8,12 @@ import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.output.Response; @@ -17,8 +23,12 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import java.util.List; import java.util.concurrent.CompletableFuture; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON; +import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -53,6 +63,52 @@ void should_provide_chat_model(String deploymentName) { }); } + class Person { + + String name; + List favouriteColors; + } + + @ParameterizedTest(name = "Deployment name: {0}") + @CsvSource({ + "gpt-4o-mini" + }) + void should_provide_chat_model_with_json_schema(String deploymentName) { + contextRunner + .withPropertyValues( + "langchain4j.azure-open-ai.chat-model.api-key=" + AZURE_OPENAI_KEY, + "langchain4j.azure-open-ai.chat-model.endpoint=" + AZURE_OPENAI_ENDPOINT, + "langchain4j.azure-open-ai.chat-model.deployment-name=" + deploymentName, + "langchain4j.azure-open-ai.chat-model.strict-json-schema=true" + ) + .run(context -> { + + ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); + + ChatRequest chatRequest = ChatRequest.builder() + .messages(singletonList(userMessage("Julien likes blue, white and red"))) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .name("Person") + .rootElement(JsonObjectSchema.builder() + .addStringProperty("name") + .addProperty("favouriteColors", JsonArraySchema.builder() + .items(new JsonStringSchema()) + .build()) + .required("name", "favouriteColors") + .build()) + .build()) + .build()) + .build(); + + assertThat(chatLanguageModel).isInstanceOf(AzureOpenAiChatModel.class); + AiMessage aiMessage = chatLanguageModel.chat(chatRequest).aiMessage(); + assertThat(aiMessage.text()).contains("{\"name\":\"Julien\",\"favouriteColors\":[\"blue\",\"white\",\"red\"]}"); + assertThat(context.getBean(AzureOpenAiChatModel.class)).isSameAs(chatLanguageModel); + }); + } + @ParameterizedTest(name = "Deployment name: {0}") @CsvSource({ "gpt-3.5-turbo"