Skip to content

Commit

Permalink
AiService support @Profile
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin7-1 committed Sep 11, 2024
1 parent 8ae6b5b commit ea974c7
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 20 deletions.
13 changes: 0 additions & 13 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<version>2.6.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<version>2.6.2</version>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Profile;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.*;

@Component
public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor {
public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {

private Environment environment;

@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
Expand All @@ -26,6 +28,8 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
for (String basePackage : basePackages) {
classPathAiServiceScanner.scan(basePackage);
}

filterBeanDefinitions(registry);
}

private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory) {
Expand Down Expand Up @@ -63,4 +67,29 @@ private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory)

return basePackages;
}

private void filterBeanDefinitions(BeanDefinitionRegistry registry) {
Arrays.stream(registry.getBeanDefinitionNames())
.filter(beanName -> {
try {
Class<?> beanClass = Class.forName(registry.getBeanDefinition(beanName).getBeanClassName());
if (beanClass.isAnnotationPresent(AiService.class) && beanClass.isAnnotationPresent(Profile.class)) {
Profile profileAnnotation = beanClass.getAnnotation(Profile.class);
String[] profiles = profileAnnotation.value();

return !environment.matchesProfiles(profiles);
} else {
return false;
}
} catch (Exception e) {
// TODO
return false;
}
}).forEach(registry::removeBeanDefinition);
}

@Override
public void setEnvironment(Environment environment) {
this.environment = environment;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Profile;
import org.springframework.core.env.Environment;

import java.lang.reflect.Method;
import java.util.*;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
Expand All @@ -30,7 +35,7 @@
public class AiServicesAutoConfig {

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor(Environment environment) {
return beanFactory -> {

// all components available in the application context
Expand Down Expand Up @@ -59,6 +64,15 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
for (String aiService : aiServices) {
Class<?> aiServiceClass = beanFactory.getType(aiService);

// Check profile
if (aiServiceClass.isAnnotationPresent(Profile.class)) {
Profile profileAnnotation = aiServiceClass.getAnnotation(Profile.class);
String[] profiles = profileAnnotation.value();
if (!environment.matchesProfiles(profiles)) {
continue;
}
}

GenericBeanDefinition aiServiceBeanDefinition = new GenericBeanDefinition();
aiServiceBeanDefinition.setBeanClass(AiServiceFactory.class);
aiServiceBeanDefinition.getConstructorArgumentValues().addGenericArgumentValue(aiServiceClass);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j.service.spring.mode.automatic.withProfiles;

import dev.langchain4j.service.spring.AiService;
import org.springframework.context.annotation.Profile;

@AiService
@Profile("!test")
public interface AiServiceWithProfiles {

String chat(String userMessage);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package dev.langchain4j.service.spring.mode.automatic.withProfiles;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

@SpringBootApplication
public class AiServiceWithProfilesApplication {

@Bean
ChatLanguageModel chatLanguageModel() {
return OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));
}

public static void main(String[] args) {
SpringApplication.run(AiServiceWithProfilesApplication.class, args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package dev.langchain4j.service.spring.mode.automatic.withProfiles;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.junit.jupiter.api.Test;
import org.springframework.beans.BeansException;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class AiServiceWithProfilesIT {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@Test
void should_not_create_ai_service() {
contextRunner
.withPropertyValues(
"spring.profiles.active=test"
)
.withUserConfiguration(AiServiceWithProfilesApplication.class)
.run(context -> assertThatThrownBy(() -> context.getBean(AiServiceWithProfiles.class))
.isInstanceOf(BeansException.class));
}

@Test
void should_create_ai_service() {
contextRunner
.withPropertyValues(
"spring.profiles.active=dev"
)
.withUserConfiguration(AiServiceWithProfilesApplication.class)
.run(context -> assertThat(context.getBean(AiServiceWithProfiles.class)).isNotNull());
}
}

0 comments on commit ea974c7

Please sign in to comment.