From ea974c7824b2cc9244987da8c3bde2cb0f87bbac Mon Sep 17 00:00:00 2001
From: Martin7-1 <1754350460@qq.com>
Date: Wed, 11 Sep 2024 14:55:09 +0800
Subject: [PATCH] AiService support @Profile
---
langchain4j-spring-boot-starter/pom.xml | 13 -------
.../spring/AiServiceScannerProcessor.java | 39 ++++++++++++++++---
.../service/spring/AiServicesAutoConfig.java | 18 ++++++++-
.../withProfiles/AiServiceWithProfiles.java | 11 ++++++
.../AiServiceWithProfilesApplication.java | 20 ++++++++++
.../withProfiles/AiServiceWithProfilesIT.java | 37 ++++++++++++++++++
6 files changed, 118 insertions(+), 20 deletions(-)
create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfiles.java
create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesApplication.java
create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesIT.java
diff --git a/langchain4j-spring-boot-starter/pom.xml b/langchain4j-spring-boot-starter/pom.xml
index 2482b1d0..8d036570 100644
--- a/langchain4j-spring-boot-starter/pom.xml
+++ b/langchain4j-spring-boot-starter/pom.xml
@@ -69,19 +69,6 @@
test
-
- org.tinylog
- tinylog-impl
- 2.6.2
- test
-
-
- org.tinylog
- slf4j-tinylog
- 2.6.2
- test
-
-
diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java
index 12e75bf6..3a8cb719 100644
--- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java
+++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java
@@ -6,18 +6,20 @@
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.ComponentScan;
+import org.springframework.context.annotation.Profile;
import org.springframework.core.annotation.AnnotationUtils;
+import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
-import java.util.Collections;
-import java.util.LinkedHashSet;
-import java.util.List;
-import java.util.Set;
+import java.util.*;
@Component
-public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor {
+public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {
+
+ private Environment environment;
@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
@@ -26,6 +28,8 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
for (String basePackage : basePackages) {
classPathAiServiceScanner.scan(basePackage);
}
+
+ filterBeanDefinitions(registry);
}
private Set getBasePackages(ConfigurableListableBeanFactory beanFactory) {
@@ -63,4 +67,29 @@ private Set getBasePackages(ConfigurableListableBeanFactory beanFactory)
return basePackages;
}
+
+ private void filterBeanDefinitions(BeanDefinitionRegistry registry) {
+ Arrays.stream(registry.getBeanDefinitionNames())
+ .filter(beanName -> {
+ try {
+ Class> beanClass = Class.forName(registry.getBeanDefinition(beanName).getBeanClassName());
+ if (beanClass.isAnnotationPresent(AiService.class) && beanClass.isAnnotationPresent(Profile.class)) {
+ Profile profileAnnotation = beanClass.getAnnotation(Profile.class);
+ String[] profiles = profileAnnotation.value();
+
+ return !environment.matchesProfiles(profiles);
+ } else {
+ return false;
+ }
+ } catch (Exception e) {
+ // TODO
+ return false;
+ }
+ }).forEach(registry::removeBeanDefinition);
+ }
+
+ @Override
+ public void setEnvironment(Environment environment) {
+ this.environment = environment;
+ }
}
diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java
index 4c3f5a6e..b08c1cb2 100644
--- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java
+++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java
@@ -15,9 +15,14 @@
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Profile;
+import org.springframework.core.env.Environment;
import java.lang.reflect.Method;
-import java.util.*;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
@@ -30,7 +35,7 @@
public class AiServicesAutoConfig {
@Bean
- BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
+ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Environment environment) {
return beanFactory -> {
// all components available in the application context
@@ -59,6 +64,15 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
for (String aiService : aiServices) {
Class> aiServiceClass = beanFactory.getType(aiService);
+ // Check profile
+ if (aiServiceClass.isAnnotationPresent(Profile.class)) {
+ Profile profileAnnotation = aiServiceClass.getAnnotation(Profile.class);
+ String[] profiles = profileAnnotation.value();
+ if (!environment.matchesProfiles(profiles)) {
+ continue;
+ }
+ }
+
GenericBeanDefinition aiServiceBeanDefinition = new GenericBeanDefinition();
aiServiceBeanDefinition.setBeanClass(AiServiceFactory.class);
aiServiceBeanDefinition.getConstructorArgumentValues().addGenericArgumentValue(aiServiceClass);
diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfiles.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfiles.java
new file mode 100644
index 00000000..c5d573a8
--- /dev/null
+++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfiles.java
@@ -0,0 +1,11 @@
+package dev.langchain4j.service.spring.mode.automatic.withProfiles;
+
+import dev.langchain4j.service.spring.AiService;
+import org.springframework.context.annotation.Profile;
+
+@AiService
+@Profile("!test")
+public interface AiServiceWithProfiles {
+
+ String chat(String userMessage);
+}
diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesApplication.java
new file mode 100644
index 00000000..5a6e133f
--- /dev/null
+++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesApplication.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.service.spring.mode.automatic.withProfiles;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.openai.OpenAiChatModel;
+import org.springframework.boot.SpringApplication;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.context.annotation.Bean;
+
+@SpringBootApplication
+public class AiServiceWithProfilesApplication {
+
+ @Bean
+ ChatLanguageModel chatLanguageModel() {
+ return OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));
+ }
+
+ public static void main(String[] args) {
+ SpringApplication.run(AiServiceWithProfilesApplication.class, args);
+ }
+}
diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesIT.java
new file mode 100644
index 00000000..5f65c4a7
--- /dev/null
+++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withProfiles/AiServiceWithProfilesIT.java
@@ -0,0 +1,37 @@
+package dev.langchain4j.service.spring.mode.automatic.withProfiles;
+
+import dev.langchain4j.service.spring.AiServicesAutoConfig;
+import org.junit.jupiter.api.Test;
+import org.springframework.beans.BeansException;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+class AiServiceWithProfilesIT {
+
+ ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));
+
+ @Test
+ void should_not_create_ai_service() {
+ contextRunner
+ .withPropertyValues(
+ "spring.profiles.active=test"
+ )
+ .withUserConfiguration(AiServiceWithProfilesApplication.class)
+ .run(context -> assertThatThrownBy(() -> context.getBean(AiServiceWithProfiles.class))
+ .isInstanceOf(BeansException.class));
+ }
+
+ @Test
+ void should_create_ai_service() {
+ contextRunner
+ .withPropertyValues(
+ "spring.profiles.active=dev"
+ )
+ .withUserConfiguration(AiServiceWithProfilesApplication.class)
+ .run(context -> assertThat(context.getBean(AiServiceWithProfiles.class)).isNotNull());
+ }
+}