diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java index 02e16ae0..2649b862 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceScannerProcessor.java @@ -46,25 +46,27 @@ private Set getBasePackages(ConfigurableListableBeanFactory beanFactory) } private void addComponentScanPackages(ConfigurableListableBeanFactory beanFactory, Set collectedBasePackages) { - beanFactory.getBeansWithAnnotation(ComponentScan.class).forEach((beanName, instance) -> { - Set componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(instance.getClass(), ComponentScan.class); - for (ComponentScan componentScan : componentScans) { - Set basePackages = new LinkedHashSet<>(); - String[] basePackagesArray = componentScan.basePackages(); - for (String pkg : basePackagesArray) { - String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg), - ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS); - Collections.addAll(basePackages, tokenized); - } - for (Class clazz : componentScan.basePackageClasses()) { - basePackages.add(ClassUtils.getPackageName(clazz)); - } - if (basePackages.isEmpty()) { - basePackages.add(ClassUtils.getPackageName(instance.getClass())); + for (String beanName : beanFactory.getBeanNamesForAnnotation(ComponentScan.class)) { + Class beanClass = beanFactory.getType(beanName); + if (beanClass != null) { + Set componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(beanClass, ComponentScan.class); + for (ComponentScan componentScan : componentScans) { + Set basePackages = new LinkedHashSet<>(); + for (String pkg : componentScan.basePackages()) { + String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg), + ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS); + Collections.addAll(basePackages, tokenized); + } + for (Class clazz : componentScan.basePackageClasses()) { + basePackages.add(ClassUtils.getPackageName(clazz)); + } + if (basePackages.isEmpty()) { + basePackages.add(ClassUtils.getPackageName(beanClass)); + } + collectedBasePackages.addAll(basePackages); } - collectedBasePackages.addAll(basePackages); } - }); + } } private void removeAiServicesWithInactiveProfiles(BeanDefinitionRegistry registry) { diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireAiServiceApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireAiServiceApplication.java new file mode 100644 index 00000000..1c67ff9c --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireAiServiceApplication.java @@ -0,0 +1,24 @@ +package dev.langchain4j.service.spring.mode.automatic.Issue2133; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +/** + * @author: qing + * @Date: 2024/11/20 + */ +@SpringBootApplication +public class TestAutowireAiServiceApplication { + + @Autowired + TestAutowireConfiguration testAutowireConfiguration; + + public static void main(String[] args) { + SpringApplication.run(TestAutowireAiServiceApplication.class, args); + } + + TestAutowireConfiguration getConfiguration() { + return testAutowireConfiguration; + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireClassAiServiceIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireClassAiServiceIT.java new file mode 100644 index 00000000..a640848b --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireClassAiServiceIT.java @@ -0,0 +1,28 @@ +package dev.langchain4j.service.spring.mode.automatic.Issue2133; + +import dev.langchain4j.service.spring.AiServicesAutoConfig; +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class TestAutowireClassAiServiceIT { + + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class)); + + @Test + void should_get_configuration_class() { + contextRunner + .withUserConfiguration(TestAutowireAiServiceApplication.class) + .withBean(TestAutowireConfiguration.class) + .run(context -> { + // given + TestAutowireAiServiceApplication application = context.getBean(TestAutowireAiServiceApplication.class); + + // should get the configuration class + assertNotNull(application.getConfiguration(), "TestConfiguration class should be not null"); + }); + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireConfiguration.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireConfiguration.java new file mode 100644 index 00000000..15d7a88a --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/Issue2133/TestAutowireConfiguration.java @@ -0,0 +1,11 @@ +package dev.langchain4j.service.spring.mode.automatic.Issue2133; + +import org.springframework.context.annotation.Configuration; + +/** + * @author: qing + * @Date: 2024/11/20 + */ +@Configuration +class TestAutowireConfiguration { +}