diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 4f347c59..65444346 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -19,9 +19,7 @@ import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.ManagedList; import org.springframework.context.ApplicationEventPublisher; -import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.Bean; -import org.springframework.lang.NonNull; import java.lang.reflect.Method; import java.util.ArrayList; @@ -42,17 +40,10 @@ import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT; import static java.util.Arrays.asList; -public class AiServicesAutoConfig implements ApplicationEventPublisherAware { - - private ApplicationEventPublisher eventPublisher; - - @Override - public void setApplicationEventPublisher(@NonNull ApplicationEventPublisher applicationEventPublisher) { - this.eventPublisher = applicationEventPublisher; - } +public class AiServicesAutoConfig { @Bean - BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { + BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(ApplicationEventPublisher eventPublisher) { return beanFactory -> { // all components available in the application context @@ -64,7 +55,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class); String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); - Set tools = new HashSet<>(); + Set toolBeanNames = new HashSet<>(); Map> beanToolSpecifications = new HashMap<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { @@ -75,7 +66,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { Class beanClass = Class.forName(beanClassName); for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { - tools.add(beanName); + toolBeanNames.add(beanName); List toolSpecifications = beanToolSpecifications.getOrDefault(beanName, new ArrayList<>()); toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod)); @@ -173,8 +164,8 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, asList(aiServiceAnnotation.tools())); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { - propertyValues.add("tools", toManagedList(tools)); - registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, tools); + propertyValues.add("tools", toManagedList(toolBeanNames)); + registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, toolBeanNames); } else { throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode()); } @@ -234,13 +225,13 @@ private static ManagedList toManagedList(Collection aiServiceClass, - Map> toolSpecifications, - Collection tools) { - return new AiServiceRegisteredEvent(AiServicesAutoConfig.class, aiServiceClass, - tools.stream() - .filter(toolSpecifications::containsKey) - .flatMap(tool -> toolSpecifications.get(tool).stream()) - .collect(Collectors.toList())); + private AiServiceRegisteredEvent buildEvent(Class aiServiceClass, + Map> beanToolSpecifications, + Collection toolBeanNames) { + return new AiServiceRegisteredEvent(this, aiServiceClass, + toolBeanNames.stream() + .filter(beanToolSpecifications::containsKey) + .flatMap(toolBeanName -> beanToolSpecifications.get(toolBeanName).stream()) + .collect(Collectors.toList())); } }