Skip to content

Commit

Permalink
WIP: declarative AI services and EasyRAG
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Mar 25, 2024
1 parent 91ad00d commit 83ace77
Show file tree
Hide file tree
Showing 68 changed files with 1,148 additions and 359 deletions.
14 changes: 13 additions & 1 deletion langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
</parent>

<artifactId>langchain4j-spring-boot-starter</artifactId>
<!-- TODO or langchain4j-spring-boot-autoconfigure ? -->
<name>Spring Boot starter for LangChain4j</name>

<dependencies>
Expand Down Expand Up @@ -76,6 +75,19 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<version>2.6.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<version>2.6.2</version>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import static dev.langchain4j.service.spring.AiServiceWiringMode.AUTOMATIC;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

/**
* Any interface annotated with {@code @AiService} will be automatically registered as a bean
* and configured to use all the following components (beans) available in the context:
* An interface annotated with {@code @AiService} will be automatically registered as a bean
* and wired with all the following components (beans) available in the context:
* <pre>
* - {@link ChatLanguageModel}
* - {@link StreamingChatLanguageModel}
Expand All @@ -27,8 +28,9 @@
* - {@link RetrievalAugmentor}
* - All beans containing methods annotated with {@code @}{@link Tool}
* </pre>
* You can also explicitly specify which components this AI Service should use by specifying bean names
* using the following properties:
* You can also explicitly specify which components (beans) should be wired into this AI Service
* by setting {@link #wiringMode()} to {@link AiServiceWiringMode#EXPLICIT}
* and specifying bean names using the following attributes:
* <pre>
* - {@link #chatModel()}
* - {@link #streamingChatModel()}
Expand All @@ -47,37 +49,49 @@
public @interface AiService {

/**
* The name of a {@link ChatLanguageModel} bean that should be used by this AI Service.
* Specifies how LangChain4j components (beans) are wired (injected) into this AI Service.
*/
AiServiceWiringMode wiringMode() default AUTOMATIC;

/**
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ChatLanguageModel} bean that should be used by this AI Service.
*/
String chatModel() default "";

/**
* The name of a {@link StreamingChatLanguageModel} bean that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link StreamingChatLanguageModel} bean that should be used by this AI Service.
*/
String streamingChatModel() default "";

/**
* The name of a {@link ChatMemory} bean that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ChatMemory} bean that should be used by this AI Service.
*/
String chatMemory() default "";

/**
* The name of a {@link ChatMemoryProvider} bean that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ChatMemoryProvider} bean that should be used by this AI Service.
*/
String chatMemoryProvider() default "";

/**
* The name of a {@link ContentRetriever} bean that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ContentRetriever} bean that should be used by this AI Service.
*/
String contentRetriever() default "";

/**
* The name of a {@link RetrievalAugmentor} bean that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link RetrievalAugmentor} bean that should be used by this AI Service.
*/
String retrievalAugmentor() default "";

/**
* The names of beans containing methods annotated with {@link Tool} that should be used by this AI Service.
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service.
*/
String[] tools() default {};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import dev.langchain4j.service.AiServices;
import org.springframework.beans.factory.FactoryBean;

import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.internal.Utils.isNullOrEmpty;
Expand All @@ -23,7 +22,7 @@ class AiServiceFactory implements FactoryBean<Object> {
private ChatMemoryProvider chatMemoryProvider;
private ContentRetriever contentRetriever;
private RetrievalAugmentor retrievalAugmentor;
private final List<Object> beansWithTools = new ArrayList<>();
private List<Object> tools;

public AiServiceFactory(Class<Object> aiServiceClass) {
this.aiServiceClass = aiServiceClass;
Expand Down Expand Up @@ -53,8 +52,8 @@ public void setRetrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
this.retrievalAugmentor = retrievalAugmentor;
}

public void setBeanWithTools(Object beanWithTools) {
this.beansWithTools.add(beanWithTools);
public void setTools(List<Object> tools) {
this.tools = tools;
}

@Override
Expand Down Expand Up @@ -86,8 +85,8 @@ public Object getObject() {
builder = builder.retrievalAugmentor(retrievalAugmentor);
}

if (!isNullOrEmpty(beansWithTools)) {
builder = builder.tools(beansWithTools);
if (!isNullOrEmpty(tools)) {
builder = builder.tools(tools);
}

return builder.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package dev.langchain4j.service.spring;

/**
* Specifies how LangChain4j components are wired (injected) into a given AI Service.
*/
public enum AiServiceWiringMode {

/**
* All LangChain4j components available in the application context are wired automatically into a given AI Service.
* If there are multiple components of the same type, an exception is thrown.
*/
AUTOMATIC,

/**
* Only explicitly specified LangChain4j components are wired into a given AI Service.
* Component (bean) names are specified using attributes of {@link AiService} annotation like this:
* {@code AiService(wiringMode = EXPLICIT, chatMemory = "<name of a ChatMemory bean>")}
*/
EXPLICIT
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,46 @@
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.service.spring.AiServiceWiringMode.AUTOMATIC;
import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT;
import static java.util.Arrays.asList;

public class AiServicesAutoConfig {

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
return beanFactory -> {

// all components available in the application context
String[] chatLanguageModels = beanFactory.getBeanNamesForType(ChatLanguageModel.class);
String[] streamingChatLanguageModels = beanFactory.getBeanNamesForType(StreamingChatLanguageModel.class);
String[] chatMemories = beanFactory.getBeanNamesForType(ChatMemory.class);
String[] chatMemoryProviders = beanFactory.getBeanNamesForType(ChatMemoryProvider.class);
String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class);
String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);

Set<String> beansWithTools = new HashSet<>();
Set<String> tools = new HashSet<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
System.out.println();
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
beansWithTools.add(beanName);
tools.add(beanName);
}
}
} catch (Exception e) {
Expand All @@ -72,60 +80,70 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {

addBeanReference(
ChatLanguageModel.class,
aiServiceAnnotation,
aiServiceAnnotation.chatModel(),
chatLanguageModels,
"chatModel",
"chatLanguageModel",
propertyValues
);

addBeanReference(
StreamingChatLanguageModel.class,
aiServiceAnnotation,
aiServiceAnnotation.streamingChatModel(),
streamingChatLanguageModels,
"streamingChatModel",
"streamingChatLanguageModel",
propertyValues
);

addBeanReference(
ChatMemory.class,
aiServiceAnnotation,
aiServiceAnnotation.chatMemory(),
chatMemories,
"chatMemory",
"chatMemory",
propertyValues
);

addBeanReference(
ChatMemoryProvider.class,
aiServiceAnnotation,
aiServiceAnnotation.chatMemoryProvider(),
chatMemoryProviders,
"chatMemoryProvider",
"chatMemoryProvider",
propertyValues
);

addBeanReference(
ContentRetriever.class,
aiServiceAnnotation,
aiServiceAnnotation.contentRetriever(),
contentRetrievers,
"contentRetriever",
"contentRetriever",
propertyValues
);

addBeanReference(
RetrievalAugmentor.class,
aiServiceAnnotation,
aiServiceAnnotation.retrievalAugmentor(),
retrievalAugmentors,
"retrievalAugmentor",
"retrievalAugmentor",
propertyValues
);

if (aiServiceAnnotation.tools().length > 0) {
for (String beanWithTools : aiServiceAnnotation.tools()) {
propertyValues.add("beanWithTools", new RuntimeBeanReference(beanWithTools));
}
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(tools));
} else {
for (String beanWithTools : beansWithTools) {
propertyValues.add("beanWithTools", new RuntimeBeanReference(beanWithTools));
}
throw illegalArgument("Unknown component selection mode: " + aiServiceAnnotation.wiringMode());
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
Expand All @@ -143,25 +161,32 @@ private static Set<Class<?>> findAiServices(ConfigurableListableBeanFactory bean
}

private static void addBeanReference(Class<?> beanType,
AiService aiServiceAnnotation,
String customBeanName,
String[] beanNames,
String propertyName,
String annotationAttributeName,
String factoryPropertyName,
MutablePropertyValues propertyValues) {
if (isNotNullOrBlank(customBeanName)) {
propertyValues.add(propertyName, new RuntimeBeanReference(customBeanName));
} else {
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
if (isNotNullOrBlank(customBeanName)) {
propertyValues.add(factoryPropertyName, new RuntimeBeanReference(customBeanName));
}
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
if (beanNames.length == 1) {
propertyValues.add(propertyName, new RuntimeBeanReference(beanNames[0]));
propertyValues.add(factoryPropertyName, new RuntimeBeanReference(beanNames[0]));
} else if (beanNames.length > 1) {
throw conflict(beanType, beanNames);
throw conflict(beanType, beanNames, annotationAttributeName);
}
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}
}

private static IllegalConfigurationException conflict(Class<?> beanType, Object[] beanNames) {
private static IllegalConfigurationException conflict(Class<?> beanType, Object[] beanNames, String attributeName) {
return illegalConfiguration("Conflict: multiple beans of type %s are found: %s. " +
"Please specify which one you wish to use in the @AiService annotation like this: " +
"@AiService(chatModel = \"<beanName>\").", beanType.getName(), Arrays.toString(beanNames));
"Please specify which one you wish to wire in the @AiService annotation like this: " +
"@AiService(wiringMode = EXPLICIT, %s = \"<beanName>\").",
beanType.getName(), Arrays.toString(beanNames), attributeName);
}

private static String lowercaseFirstLetter(String text) {
Expand All @@ -170,4 +195,12 @@ private static String lowercaseFirstLetter(String text) {
}
return text.substring(0, 1).toLowerCase() + text.substring(1);
}

private static ManagedList<RuntimeBeanReference> toManagedList(Collection<String> beanNames) {
ManagedList<RuntimeBeanReference> managedList = new ManagedList<>();
for (String beanName : beanNames) {
managedList.add(new RuntimeBeanReference(beanName));
}
return managedList;
}
}

This file was deleted.

This file was deleted.

Loading

0 comments on commit 83ace77

Please sign in to comment.