Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
original PR link: langchain4j#77
Issues link: langchain4j/langchain4j#2112
  • Loading branch information
catofdestruction committed Dec 9, 2024
1 parent 92b30ec commit 3880155
Showing 1 changed file with 14 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.annotation.Bean;
import org.springframework.lang.NonNull;

import java.lang.reflect.Method;
import java.util.ArrayList;
Expand All @@ -42,17 +40,10 @@
import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT;
import static java.util.Arrays.asList;

public class AiServicesAutoConfig implements ApplicationEventPublisherAware {

private ApplicationEventPublisher eventPublisher;

@Override
public void setApplicationEventPublisher(@NonNull ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}
public class AiServicesAutoConfig {

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(ApplicationEventPublisher eventPublisher) {
return beanFactory -> {

// all components available in the application context
Expand All @@ -64,7 +55,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

Set<String> tools = new HashSet<>();
Set<String> toolBeanNames = new HashSet<>();
Map<String, List<ToolSpecification>> beanToolSpecifications = new HashMap<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Expand All @@ -75,7 +66,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
toolBeanNames.add(beanName);
List<ToolSpecification> toolSpecifications =
beanToolSpecifications.getOrDefault(beanName, new ArrayList<>());
toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod));
Expand Down Expand Up @@ -173,8 +164,8 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, asList(aiServiceAnnotation.tools()));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(tools));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, tools);
propertyValues.add("tools", toManagedList(toolBeanNames));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, toolBeanNames);
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}
Expand Down Expand Up @@ -234,13 +225,13 @@ private static ManagedList<RuntimeBeanReference> toManagedList(Collection<String
return managedList;
}

private static AiServiceRegisteredEvent buildEvent(Class<?> aiServiceClass,
Map<String, List<ToolSpecification>> toolSpecifications,
Collection<String> tools) {
return new AiServiceRegisteredEvent(AiServicesAutoConfig.class, aiServiceClass,
tools.stream()
.filter(toolSpecifications::containsKey)
.flatMap(tool -> toolSpecifications.get(tool).stream())
.collect(Collectors.toList()));
private AiServiceRegisteredEvent buildEvent(Class<?> aiServiceClass,
Map<String, List<ToolSpecification>> beanToolSpecifications,
Collection<String> toolBeanNames) {
return new AiServiceRegisteredEvent(this, aiServiceClass,
toolBeanNames.stream()
.filter(beanToolSpecifications::containsKey)
.flatMap(toolBeanName -> beanToolSpecifications.get(toolBeanName).stream())
.collect(Collectors.toList()));
}
}

0 comments on commit 3880155

Please sign in to comment.