Merge pull request #4 from arey/feature/rag-query-router
Add a QueryRouter to utilize the Embedding Store when appropriate
arey authored Oct 30, 2024
2 parents 7cfcc01 + 1d972c1 commit 3462623
Showing 10 changed files with 129 additions and 43 deletions.
1 change: 1 addition & 0 deletions readme.md
@@ -25,6 +25,7 @@ This can be achieved thanks to:
[AssistantTool](src/main/java/org/springframework/samples/petclinic/chat/AssistantTool.java) uses Java records as the LLM input/output data structure.
* **Retrieval-Augmented Generation** (RAG) enables an LLM to incorporate and respond based on specific data—such as data from the petclinic database—by ingesting and referencing it during interactions.
The [AssistantConfiguration](src/main/java/org/springframework/samples/petclinic/chat/AssistantConfiguration.java) declares the `EmbeddingModel`, `InMemoryEmbeddingStore` and `EmbeddingStoreContentRetriever` beans, while the [EmbeddingStoreInit](src/main/java/org/springframework/samples/petclinic/chat/EmbeddingStoreInit.java) class handles vets data ingestion at startup.
The [VetQueryRouter](src/main/java/org/springframework/samples/petclinic/chat/VetQueryRouter.java) demonstrates how to conditionally skip retrieval, with decision-making driven by an LLM.
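
For illustration, here is a minimal sketch (not part of this commit) of how a `RetrievalAugmentor` with a custom `QueryRouter` is typically plugged into a LangChain4j assistant. The `Assistant` interface and the `AssistantSketch` factory below are assumptions for the example, not classes from Spring Petclinic:

```java
package org.springframework.samples.petclinic.chat;

import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;

// Hypothetical assistant contract; the real project defines its own interface.
interface Assistant {
    String chat(String userMessage);
}

class AssistantSketch {

    static Assistant create(ChatLanguageModel model, ContentRetriever vetContentRetriever) {
        return AiServices.builder(Assistant.class)
            .chatLanguageModel(model)
            // Every user query first passes through the VetQueryRouter, which decides
            // whether the vet embedding store should be queried at all.
            .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                .queryRouter(new VetQueryRouter(model, vetContentRetriever))
                .build())
            .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
            .build();
    }
}
```

The same proxy can also be assembled declaratively by the LangChain4j Spring Boot starter; the explicit builder is shown only to make the query-routing step visible.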

Spring Petclinic integrates a Chatbot that allows you to interact with the application in natural language. Here are **some examples** of what you could ask:

src/main/java/org/springframework/samples/petclinic/chat/AssistantConfiguration.java
@@ -3,8 +3,12 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import org.springframework.context.annotation.Bean;
@@ -37,4 +41,11 @@ EmbeddingStoreContentRetriever contentRetriever(InMemoryEmbeddingStore<TextSegme
return new EmbeddingStoreContentRetriever(embeddingStore, embeddingModel);
}

@Bean
RetrievalAugmentor retrievalAugmentor(ChatLanguageModel chatLanguageModel, ContentRetriever vetContentRetriever) {
return DefaultRetrievalAugmentor.builder()
.queryRouter(new VetQueryRouter(chatLanguageModel, vetContentRetriever))
.build();
}

}
66 changes: 66 additions & 0 deletions src/main/java/org/springframework/samples/petclinic/chat/VetQueryRouter.java
@@ -0,0 +1,66 @@
package org.springframework.samples.petclinic.chat;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.QueryRouter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;

import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

/**
 * This query router illustrates how to conditionally skip retrieval. In some cases, retrieval
* isn’t needed, such as when a user simply says "Hi". Additionally, only the clinic's
* veterinarians are indexed in the Embedding Store.
* <p>
* To implement this, a custom {@link QueryRouter} is the simplest approach. When
* retrieval is unnecessary, the QueryRouter returns an empty list, indicating that the
* query won’t be routed to any {@link ContentRetriever}.
* <p>
* Decision-making relies on an LLM, which determines whether retrieval is needed based on
* the user's query.
* <p>
*
* @see <a href=
* "https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_06_Advanced_RAG_Skip_Retrieval_Example.java">_06_Advanced_RAG_Skip_Retrieval_Example.java</a>
*/
class VetQueryRouter implements QueryRouter {

private static final Logger LOGGER = LoggerFactory.getLogger(VetQueryRouter.class);

private static final PromptTemplate PROMPT_TEMPLATE = PromptTemplate.from("""
Is the following query related to one or more veterinarians of the pet clinic?
Answer only 'yes' or 'no'.
Query: {{it}}
""");

private final ContentRetriever vetContentRetriever;

private final ChatLanguageModel chatLanguageModel;

public VetQueryRouter(ChatLanguageModel chatLanguageModel, ContentRetriever vetContentRetriever) {
this.chatLanguageModel = chatLanguageModel;
this.vetContentRetriever = vetContentRetriever;
}

@Override
public Collection<ContentRetriever> route(Query query) {
Prompt prompt = PROMPT_TEMPLATE.apply(query.text());

AiMessage aiMessage = chatLanguageModel.generate(prompt.toUserMessage()).content();
LOGGER.debug("LLM decided: {}", aiMessage.text());

if (aiMessage.text().toLowerCase().contains("yes")) {
return singletonList(vetContentRetriever);
}
return emptyList();
}

}
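
As a quick, hand-wired illustration of the router's decision flow (not part of the commit; `model` and `vetContentRetriever` stand for whatever `ChatLanguageModel` and vet `ContentRetriever` beans the application already provides, and `Query.from` is the LangChain4j factory method):

```java
package org.springframework.samples.petclinic.chat;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.QueryRouter;

import java.util.Collection;

class VetQueryRouterSketch {

    // Hypothetical helper: 'model' and 'vetContentRetriever' are assumed to be
    // beans that already exist in the application context.
    static void demo(ChatLanguageModel model, ContentRetriever vetContentRetriever) {
        QueryRouter router = new VetQueryRouter(model, vetContentRetriever);

        // A greeting is unrelated to veterinarians: the LLM answers "no", the router
        // returns an empty collection, and no embedding-store lookup takes place.
        Collection<ContentRetriever> forGreeting = router.route(Query.from("Hi"));

        // A vet-related question is routed to the single vet content retriever, whose
        // results the RetrievalAugmentor then injects into the LLM prompt.
        Collection<ContentRetriever> forVets = router.route(Query.from("Which vets specialize in radiology?"));
    }
}
```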
10 changes: 9 additions & 1 deletion src/main/resources/application.properties
@@ -20,6 +20,7 @@ management.endpoints.web.exposure.include=*
# Logging
logging.level.dev.langchain4j=DEBUG
logging.level.dev.ai4j.openai4j=DEBUG
logging.level.org.springframework.samples.petclinic=DEBUG
# logging.level.org.springframework.web=DEBUG
# logging.level.org.springframework.context.annotation=TRACE

@@ -32,12 +33,19 @@ langchain4j.azure-open-ai.streaming-chat-model.api-key=${AZURE_OPENAI_KEY}
langchain4j.azure-open-ai.streaming-chat-model.endpoint=${AZURE_OPENAI_ENDPOINT}
langchain4j.azure-open-ai.streaming-chat-model.deployment-name=gpt-4o
langchain4j.azure-open-ai.streaming-chat-model.log-requests-and-responses=true
langchain4j.azure-open-ai.chat-model.api-key=${AZURE_OPENAI_KEY}
langchain4j.azure-open-ai.chat-model.endpoint=${AZURE_OPENAI_ENDPOINT}
langchain4j.azure-open-ai.chat-model.deployment-name=gpt-4o
langchain4j.azure-open-ai.chat-model.log-requests-and-responses=true

# OpenAI
# These parameters only apply when using the langchain4j-open-ai-spring-boot-starter dependency
langchain4j.open-ai.streaming-chat-model.api-key=${OPENAI_API_KEY}
langchain4j.open-ai.streaming-chat-model.model-name=gpt-4o
langchain4j.open-ai.streaming-chat-model.log-requests=true
langchain4j.open-ai.streaming-chat-model.log-responses=true

langchain4j.open-ai.chat-model.api-key=${OPENAI_API_KEY}
langchain4j.open-ai.chat-model.model-name=gpt-4o-mini
langchain4j.open-ai.chat-model.log-requests=true
langchain4j.open-ai.chat-model.log-responses=true

1 change: 1 addition & 0 deletions src/main/resources/prompts/system.st
@@ -1,6 +1,7 @@
You are a friendly AI assistant designed to help with the management of a veterinarian pet clinic called Spring Petclinic.
Your job is to answer questions about and to perform actions on the user's behalf, mainly around
veterinarians, owners, owners' pets and owners' visits.
If you need access to pet owners or pet types, list and locate them without asking the user.
You are required to answer in a professional manner. If you don't know the answer, politely tell the user
you don't know the answer, then ask the user a follow-up question to try and clarify the question they are asking.
If you do know the answer, provide the answer but do not provide any additional follow-up questions.
src/test/java/org/springframework/samples/petclinic/MySqlIntegrationTests.java
@@ -16,8 +16,6 @@

package org.springframework.samples.petclinic;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledInNativeImage;
import org.springframework.beans.factory.annotation.Autowired;
@@ -37,11 +35,10 @@
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT,
properties = { "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" })
@ActiveProfiles("mysql")
import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
@ActiveProfiles(profiles = { "mysql", "test" })
@Testcontainers(disabledWithoutDocker = true)
@DisabledInNativeImage
@DisabledInAotMode
src/test/java/org/springframework/samples/petclinic/PetClinicIntegrationTests.java
@@ -16,8 +16,6 @@

package org.springframework.samples.petclinic;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringApplication;
@@ -29,12 +27,13 @@
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.samples.petclinic.vet.VetRepository;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.web.client.RestTemplate;

@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT,
properties = { "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" })
import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
@ActiveProfiles("test")
public class PetClinicIntegrationTests {

@LocalServerPort
src/test/java/org/springframework/samples/petclinic/PostgresIntegrationTests.java
@@ -16,13 +16,6 @@

package org.springframework.samples.petclinic;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.BeforeAll;
@@ -47,12 +40,16 @@
import org.springframework.web.client.RestTemplate;
import org.testcontainers.DockerClientFactory;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT,
properties = { "spring.docker.compose.skip.in-tests=false", "spring.docker.compose.profiles.active=postgres",
"langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" })
@ActiveProfiles("postgres")
properties = { "spring.docker.compose.skip.in-tests=false", "spring.docker.compose.profiles.active=postgres" })
@ActiveProfiles(profiles = { "postgres", "test" })
@DisabledInNativeImage
public class PostgresIntegrationTests {

src/test/java/org/springframework/samples/petclinic/system/CrashControllerIntegrationTests.java
@@ -16,12 +16,6 @@

package org.springframework.samples.petclinic.system;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -32,13 +26,14 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.*;
import org.springframework.test.context.ActiveProfiles;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

/**
* Integration Test for {@link CrashController}.
@@ -47,10 +42,8 @@
*/
// NOT Waiting https://github.com/spring-projects/spring-boot/issues/5574
@SpringBootTest(webEnvironment = RANDOM_PORT,
properties = { "server.error.include-message=ALWAYS", "management.endpoints.enabled-by-default=false",
"langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY",
"langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" })
properties = { "server.error.include-message=ALWAYS", "management.endpoints.enabled-by-default=false" })
@ActiveProfiles("test")
class CrashControllerIntegrationTests {

@SpringBootApplication(exclude = { DataSourceAutoConfiguration.class,
13 changes: 13 additions & 0 deletions src/test/resources/application-test.properties
@@ -0,0 +1,13 @@
## Override some properties of the application.properties for unit tests

# Azure OpenAI
# These parameters only apply when using the langchain4j-azure-open-ai-spring-boot-starter dependency
langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY
langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT
langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY
langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT

# OpenAI
# These parameters only apply when using the langchain4j-open-ai-spring-boot-starter dependency
langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY
langchain4j.open-ai.chat-model.api-key=FAKE_KEY
