diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java index 7ce673a0..7e215c8e 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java @@ -77,12 +77,10 @@ public Object getObject() { builder.chatMemoryProvider(chatMemoryProvider); } - if (contentRetriever != null) { - builder = builder.contentRetriever(contentRetriever); - } - if (retrievalAugmentor != null) { builder = builder.retrievalAugmentor(retrievalAugmentor); + } else if (contentRetriever != null) { + builder = builder.contentRetriever(contentRetriever); } if (!isNullOrEmpty(tools)) { diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentor.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentor.java new file mode 100644 index 00000000..6ec2b225 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentor.java @@ -0,0 +1,9 @@ +package dev.langchain4j.service.spring.mode.automatic.withContentRetrieverAndRetrievalAugmentor.withRetrievalAugmentor; + +import dev.langchain4j.service.spring.AiService; + +@AiService +interface AiServiceWithContentRetrieverAndRetrievalAugmentor { + + String chat(String userMessage); +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorApplication.java new file mode 100644 index 00000000..340b62a8 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorApplication.java @@ -0,0 +1,31 @@ +package dev.langchain4j.service.spring.mode.automatic.withContentRetrieverAndRetrievalAugmentor.withRetrievalAugmentor; + +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + +import static java.util.Collections.singletonList; + +@SpringBootApplication +class AiServiceWithContentRetrieverAndRetrievalAugmentorApplication { + + @Bean + ContentRetriever contentRetriever() { + return query -> singletonList(Content.from("My name is Klaus.")); + } + + @Bean + RetrievalAugmentor retrievalAugmentor(ContentRetriever contentRetriever) { + return DefaultRetrievalAugmentor.builder() + .contentRetriever(contentRetriever) + .build(); + } + + public static void main(String[] args) { + SpringApplication.run(AiServiceWithContentRetrieverAndRetrievalAugmentorApplication.class, args); + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorIT.java new file mode 100644 index 00000000..2142d2d4 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withContentRetrieverAndRetrievalAugmentor/withRetrievalAugmentor/AiServiceWithContentRetrieverAndRetrievalAugmentorIT.java @@ -0,0 +1,37 @@ +package dev.langchain4j.service.spring.mode.automatic.withContentRetrieverAndRetrievalAugmentor.withRetrievalAugmentor; + +import dev.langchain4j.service.spring.AiServicesAutoConfig; +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY; +import static org.assertj.core.api.Assertions.assertThat; + +class AiServiceWithContentRetrieverAndRetrievalAugmentorIT { + + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class)); + + @Test + void should_create_AI_service_with_content_retriever_and_retrieval_augmentor() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY, + "langchain4j.open-ai.chat-model.max-tokens=20", + "langchain4j.open-ai.chat-model.temperature=0.0" + ) + .withUserConfiguration(AiServiceWithContentRetrieverAndRetrievalAugmentorApplication.class) + .run(context -> { + + // given + AiServiceWithContentRetrieverAndRetrievalAugmentor aiService = context.getBean(AiServiceWithContentRetrieverAndRetrievalAugmentor.class); + + // when + String answer = aiService.chat("What is my name?"); + + // then + assertThat(answer).containsIgnoringCase("Klaus"); + }); + } +} \ No newline at end of file