Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
original PR link: langchain4j#77
Issues link: langchain4j/langchain4j#2112
  • Loading branch information
catofdestruction committed Dec 9, 2024
1 parent 781c479 commit 1ecb553
Showing 1 changed file with 3 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
Expand All @@ -56,7 +53,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

Set<String> toolBeanNames = new HashSet<>();
Map<String, List<ToolSpecification>> beanToolSpecifications = new HashMap<>();
List<ToolSpecification> toolSpecifications = new ArrayList<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
Expand All @@ -67,10 +64,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
toolBeanNames.add(beanName);
List<ToolSpecification> toolSpecifications =
beanToolSpecifications.getOrDefault(beanName, new ArrayList<>());
toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod));
beanToolSpecifications.put(beanName, toolSpecifications);
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -162,10 +156,10 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Applicati
AiServiceRegisteredEvent registeredEvent;
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, asList(aiServiceAnnotation.tools()));
registeredEvent = new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications);
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(toolBeanNames));
registeredEvent = buildEvent(aiServiceClass, beanToolSpecifications, toolBeanNames);
registeredEvent = new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications);
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}
Expand Down Expand Up @@ -224,14 +218,4 @@ private static ManagedList<RuntimeBeanReference> toManagedList(Collection<String
}
return managedList;
}

private AiServiceRegisteredEvent buildEvent(Class<?> aiServiceClass,
Map<String, List<ToolSpecification>> beanToolSpecifications,
Collection<String> toolBeanNames) {
return new AiServiceRegisteredEvent(this, aiServiceClass,
toolBeanNames.stream()
.filter(beanToolSpecifications::containsKey)
.flatMap(toolBeanName -> beanToolSpecifications.get(toolBeanName).stream())
.collect(Collectors.toList()));
}
}

0 comments on commit 1ecb553

Please sign in to comment.