diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 65444346..d4db5c25 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -25,12 +25,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; import static dev.langchain4j.internal.Exceptions.illegalArgument; @@ -56,7 +53,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); Set toolBeanNames = new HashSet<>(); - Map> beanToolSpecifications = new HashMap<>(); + List toolSpecifications = new ArrayList<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName(); @@ -67,10 +64,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { toolBeanNames.add(beanName); - List toolSpecifications = - beanToolSpecifications.getOrDefault(beanName, new ArrayList<>()); toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod)); - beanToolSpecifications.put(beanName, toolSpecifications); } } } catch (Exception e) { @@ -162,10 +156,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati AiServiceRegisteredEvent registeredEvent; if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); - registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, asList(aiServiceAnnotation.tools())); + registeredEvent = new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { propertyValues.add("tools", toManagedList(toolBeanNames)); - registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, toolBeanNames); + registeredEvent = new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications); } else { throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode()); } @@ -224,14 +218,4 @@ private static ManagedList toManagedList(Collection aiServiceClass, - Map> beanToolSpecifications, - Collection toolBeanNames) { - return new AiServiceRegisteredEvent(this, aiServiceClass, - toolBeanNames.stream() - .filter(beanToolSpecifications::containsKey) - .flatMap(toolBeanName -> beanToolSpecifications.get(toolBeanName).stream()) - .collect(Collectors.toList())); - } }