diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 69f6fddf..629f3894 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -1,6 +1,8 @@ package dev.langchain4j.service.spring; import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; import dev.langchain4j.exception.IllegalConfigurationException; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; @@ -9,19 +11,21 @@ import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.ManagedList; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.Bean; import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; +import java.util.*; import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; import static dev.langchain4j.internal.Exceptions.illegalArgument; @@ -31,7 +35,16 @@ import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT; import static java.util.Arrays.asList; -public class AiServicesAutoConfig { +public class AiServicesAutoConfig implements ApplicationEventPublisherAware { + + private static final Logger log = LoggerFactory.getLogger(AiServicesAutoConfig.class); + + private ApplicationEventPublisher eventPublisher; + + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } @Bean BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { @@ -46,7 +59,8 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class); String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); - Set tools = new HashSet<>(); + Set toolBeanNames = new HashSet<>(); + List toolSpecifications = new ArrayList<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName(); @@ -56,7 +70,13 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { Class beanClass = Class.forName(beanClassName); for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { - tools.add(beanName); + toolBeanNames.add(beanName); + try { + toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod)); + } catch (Exception e) { + log.warn("Cannot convert %s.%s method annotated with @Tool into ToolSpecification" + .formatted(beanClass.getName(), beanMethod.getName()), e); + } } } } catch (Exception e) { @@ -148,7 +168,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { - propertyValues.add("tools", toManagedList(tools)); + propertyValues.add("tools", toManagedList(toolBeanNames)); } else { throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode()); } @@ -156,6 +176,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory; registry.removeBeanDefinition(aiService); registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition); + + if (eventPublisher != null) { + eventPublisher.publishEvent(new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications)); + } } }; } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java new file mode 100644 index 00000000..8d15516e --- /dev/null +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java @@ -0,0 +1,28 @@ +package dev.langchain4j.service.spring.event; + +import dev.langchain4j.agent.tool.ToolSpecification; +import org.springframework.context.ApplicationEvent; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; + +public class AiServiceRegisteredEvent extends ApplicationEvent { + + private final Class aiServiceClass; + private final List toolSpecifications; + + public AiServiceRegisteredEvent(Object source, Class aiServiceClass, List toolSpecifications) { + super(source); + this.aiServiceClass = aiServiceClass; + this.toolSpecifications = copyIfNotNull(toolSpecifications); + } + + public Class aiServiceClass() { + return aiServiceClass; + } + + public List toolSpecifications() { + return toolSpecifications; + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java index 1104d337..94e82885 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java @@ -1,12 +1,26 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.ApplicationListener; + +import java.util.List; @SpringBootApplication -class AiServiceWithToolsApplication { +class AiServiceWithToolsApplication implements ApplicationListener { public static void main(String[] args) { SpringApplication.run(AiServiceWithToolsApplication.class, args); } + + @Override + public void onApplicationEvent(AiServiceRegisteredEvent event) { + Class aiServiceClass = event.aiServiceClass(); + List toolSpecifications = event.toolSpecifications(); + for (int i = 0; i < toolSpecifications.size(); i++) { + System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i)); + } + } } diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java index 49bf50b5..3717cc10 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java @@ -1,11 +1,16 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.service.spring.AiServicesAutoConfig; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect; +import dev.langchain4j.service.spring.mode.automatic.withTools.listener.AiServiceRegisteredEventListener; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import java.util.List; + import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY; import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY; import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION; @@ -16,6 +21,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; class AiServicesAutoConfigIT { @@ -69,6 +75,32 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag }); } + @Test + void should_receive_ai_service_registered_event() { + contextRunner + .withUserConfiguration(AiServiceWithToolsApplication.class) + .run(context -> { + + // given + AiServiceRegisteredEventListener listener = context.getBean(AiServiceRegisteredEventListener.class); + + // then should receive AiServiceRegisteredEvent + assertTrue(listener.isEventReceived()); + assertEquals(1, listener.getReceivedEvents().size()); + + AiServiceRegisteredEvent event = listener.getReceivedEvents().stream().findFirst().orElse(null); + assertNotNull(event); + assertEquals(AiServiceWithTools.class, event.aiServiceClass()); + assertEquals(4, event.toolSpecifications().size()); + + List tools = event.toolSpecifications().stream().map(ToolSpecification::name).toList(); + assertTrue(tools.contains("getCurrentDate")); + assertTrue(tools.contains("getCurrentTime")); + assertTrue(tools.contains("getToolObserverPackageName")); + assertTrue(tools.contains("getToolObserverKey")); + }); + } + @Test void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() { contextRunner diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java new file mode 100644 index 00000000..cb41209a --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AbstractApplicationListener.java @@ -0,0 +1,24 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.listener; + +import org.springframework.context.ApplicationEvent; +import org.springframework.context.ApplicationListener; + +import java.util.ArrayList; +import java.util.List; + +public class AbstractApplicationListener implements ApplicationListener { + private final List receivedEvents = new ArrayList<>(); + + @Override + public void onApplicationEvent(E event) { + receivedEvents.add(event); + } + + public List getReceivedEvents() { + return receivedEvents; + } + + public boolean isEventReceived() { + return !receivedEvents.isEmpty(); + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java new file mode 100644 index 00000000..35ef0c1e --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/listener/AiServiceRegisteredEventListener.java @@ -0,0 +1,8 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.listener; + +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import org.springframework.stereotype.Component; + +@Component +public class AiServiceRegisteredEventListener extends AbstractApplicationListener { +}