Skip to content

Commit

Permalink
[FEATURE] Get all metadata of AiService & Tool when the Spring Boot s…
Browse files Browse the repository at this point in the history
  • Loading branch information
catofdestruction committed Nov 15, 2024
1 parent c6c183b commit c7a57e7
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -9,19 +11,27 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
Expand All @@ -31,7 +41,14 @@
import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT;
import static java.util.Arrays.asList;

public class AiServicesAutoConfig {
public class AiServicesAutoConfig implements ApplicationEventPublisherAware {

private ApplicationEventPublisher eventPublisher;

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Expand All @@ -47,12 +64,15 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

Set<String> tools = new HashSet<>();
Map<String, ToolSpecification> toolSpecifications = new HashMap<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
if (tools.add(beanName)) {
toolSpecifications.put(beanName, ToolSpecifications.toolSpecificationFrom(beanMethod));
}
}
}
} catch (Exception e) {
Expand All @@ -70,7 +90,6 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
MutablePropertyValues propertyValues = aiServiceBeanDefinition.getPropertyValues();

AiService aiServiceAnnotation = aiServiceClass.getAnnotation(AiService.class);

addBeanReference(
ChatLanguageModel.class,
aiServiceAnnotation,
Expand Down Expand Up @@ -140,18 +159,24 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
"moderationModel",
propertyValues
);

AiServiceRegisteredEvent registeredEvent;
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
registeredEvent = buildEvent(aiServiceClass, toolSpecifications, asList(aiServiceAnnotation.tools()));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(tools));
registeredEvent = buildEvent(aiServiceClass, toolSpecifications, tools);
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
registry.removeBeanDefinition(aiService);
registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);

if (eventPublisher != null) {
eventPublisher.publishEvent(registeredEvent);
}
}
};
}
Expand Down Expand Up @@ -199,4 +224,13 @@ private static ManagedList<RuntimeBeanReference> toManagedList(Collection<String
}
return managedList;
}

private static AiServiceRegisteredEvent buildEvent(Class<?> aiServiceClass,
Map<String, ToolSpecification> toolSpecifications,
Collection<String> tools) {
return new AiServiceRegisteredEvent(aiServiceClass, aiServiceClass, tools.stream()
.filter(toolSpecifications::containsKey)
.map(toolSpecifications::get)
.collect(Collectors.toList()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.langchain4j.service.spring.event;

import dev.langchain4j.agent.tool.ToolSpecification;
import org.springframework.context.ApplicationEvent;

import java.util.List;

public class AiServiceRegisteredEvent extends ApplicationEvent {

private final Class<?> aiServiceClass;
private final List<ToolSpecification> toolSpecifications;

public AiServiceRegisteredEvent(Object source, Class<?> aiServiceClass, List<ToolSpecification> toolSpecifications) {
super(source);
this.aiServiceClass = aiServiceClass;
this.toolSpecifications = toolSpecifications;
}

public Class<?> getAiServiceClass() {
return aiServiceClass;
}

public List<ToolSpecification> getToolSpecifications() {
return toolSpecifications;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package dev.langchain4j.service.spring.event;

import org.springframework.context.ApplicationListener;

public interface AiServiceRegisteredEventListener extends ApplicationListener<AiServiceRegisteredEvent> {
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEventListener;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

import java.util.List;

@SpringBootApplication
class AiServiceWithToolsApplication {
class AiServiceWithToolsApplication implements AiServiceRegisteredEventListener {

public static void main(String[] args) {
SpringApplication.run(AiServiceWithToolsApplication.class, args);
}

@Override
public void onApplicationEvent(AiServiceRegisteredEvent event) {
Class<?> aiServiceClass = event.getAiServiceClass();
List<ToolSpecification> toolSpecifications = event.getToolSpecifications();
for (int i = 0; i < toolSpecifications.size(); i++) {
System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i));
}
}
}

0 comments on commit c7a57e7

Please sign in to comment.