From c7a57e73a9421cecdcab718b6f8bf0712cd6c7e2 Mon Sep 17 00:00:00 2001 From: catofdestruction Date: Fri, 15 Nov 2024 19:56:12 +0800 Subject: [PATCH] [FEATURE] Get all metadata of AiService & Tool when the Spring Boot starts https://github.com/langchain4j/langchain4j/issues/2112 --- .../service/spring/AiServicesAutoConfig.java | 42 +++++++++++++++++-- .../event/AiServiceRegisteredEvent.java | 26 ++++++++++++ .../AiServiceRegisteredEventListener.java | 6 +++ .../AiServiceWithToolsApplication.java | 16 ++++++- 4 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java create mode 100644 langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEventListener.java diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 90567db3..ea246b2a 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -1,6 +1,8 @@ package dev.langchain4j.service.spring; import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; import dev.langchain4j.exception.IllegalConfigurationException; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; @@ -9,19 +11,27 @@ import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.ManagedList; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.Bean; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; import static dev.langchain4j.internal.Exceptions.illegalArgument; @@ -31,7 +41,14 @@ import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT; import static java.util.Arrays.asList; -public class AiServicesAutoConfig { +public class AiServicesAutoConfig implements ApplicationEventPublisherAware { + + private ApplicationEventPublisher eventPublisher; + + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.eventPublisher = applicationEventPublisher; + } @Bean BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { @@ -47,12 +64,15 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); Set tools = new HashSet<>(); + Map toolSpecifications = new HashMap<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { Class beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName()); for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { - tools.add(beanName); + if (tools.add(beanName)) { + toolSpecifications.put(beanName, ToolSpecifications.toolSpecificationFrom(beanMethod)); + } } } } catch (Exception e) { @@ -70,7 +90,6 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { MutablePropertyValues propertyValues = aiServiceBeanDefinition.getPropertyValues(); AiService aiServiceAnnotation = aiServiceClass.getAnnotation(AiService.class); - addBeanReference( ChatLanguageModel.class, aiServiceAnnotation, @@ -140,11 +159,13 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { "moderationModel", propertyValues ); - + AiServiceRegisteredEvent registeredEvent; if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); + registeredEvent = buildEvent(aiServiceClass, toolSpecifications, asList(aiServiceAnnotation.tools())); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { propertyValues.add("tools", toManagedList(tools)); + registeredEvent = buildEvent(aiServiceClass, toolSpecifications, tools); } else { throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode()); } @@ -152,6 +173,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory; registry.removeBeanDefinition(aiService); registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition); + + if (eventPublisher != null) { + eventPublisher.publishEvent(registeredEvent); + } } }; } @@ -199,4 +224,13 @@ private static ManagedList toManagedList(Collection aiServiceClass, + Map toolSpecifications, + Collection tools) { + return new AiServiceRegisteredEvent(aiServiceClass, aiServiceClass, tools.stream() + .filter(toolSpecifications::containsKey) + .map(toolSpecifications::get) + .collect(Collectors.toList())); + } } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java new file mode 100644 index 00000000..28a15ed6 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEvent.java @@ -0,0 +1,26 @@ +package dev.langchain4j.service.spring.event; + +import dev.langchain4j.agent.tool.ToolSpecification; +import org.springframework.context.ApplicationEvent; + +import java.util.List; + +public class AiServiceRegisteredEvent extends ApplicationEvent { + + private final Class aiServiceClass; + private final List toolSpecifications; + + public AiServiceRegisteredEvent(Object source, Class aiServiceClass, List toolSpecifications) { + super(source); + this.aiServiceClass = aiServiceClass; + this.toolSpecifications = toolSpecifications; + } + + public Class getAiServiceClass() { + return aiServiceClass; + } + + public List getToolSpecifications() { + return toolSpecifications; + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEventListener.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEventListener.java new file mode 100644 index 00000000..4248f0d9 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/event/AiServiceRegisteredEventListener.java @@ -0,0 +1,6 @@ +package dev.langchain4j.service.spring.event; + +import org.springframework.context.ApplicationListener; + +public interface AiServiceRegisteredEventListener extends ApplicationListener { +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java index 1104d337..c4f89e04 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServiceWithToolsApplication.java @@ -1,12 +1,26 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import dev.langchain4j.service.spring.event.AiServiceRegisteredEventListener; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import java.util.List; + @SpringBootApplication -class AiServiceWithToolsApplication { +class AiServiceWithToolsApplication implements AiServiceRegisteredEventListener { public static void main(String[] args) { SpringApplication.run(AiServiceWithToolsApplication.class, args); } + + @Override + public void onApplicationEvent(AiServiceRegisteredEvent event) { + Class aiServiceClass = event.getAiServiceClass(); + List toolSpecifications = event.getToolSpecifications(); + for (int i = 0; i < toolSpecifications.size(); i++) { + System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i)); + } + } }