-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from arey/feature/rag-query-router
Add a QueryRouter to utilize the Embedding Store when appropriate
- Loading branch information
Showing
10 changed files
with
129 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
src/main/java/org/springframework/samples/petclinic/chat/VetQueryRouter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package org.springframework.samples.petclinic.chat; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.input.Prompt; | ||
import dev.langchain4j.model.input.PromptTemplate; | ||
import dev.langchain4j.rag.content.retriever.ContentRetriever; | ||
import dev.langchain4j.rag.query.Query; | ||
import dev.langchain4j.rag.query.router.QueryRouter; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.Collection; | ||
|
||
import static java.util.Collections.emptyList; | ||
import static java.util.Collections.singletonList; | ||
|
||
/** | ||
* This filter illustrates how to conditionally skip retrieval. In some cases, retrieval | ||
* isn’t needed, such as when a user simply says "Hi". Additionally, only the clinic's | ||
* veterinarians are indexed in the Embedding Store. | ||
* <p> | ||
* To implement this, a custom {@link QueryRouter} is the simplest approach. When | ||
* retrieval is unnecessary, the QueryRouter returns an empty list, indicating that the | ||
* query won’t be routed to any {@link ContentRetriever}. | ||
* <p> | ||
* Decision-making relies on an LLM, which determines whether retrieval is needed based on | ||
* the user's query. | ||
* <p> | ||
* | ||
* @see <a href= | ||
* "https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_06_Advanced_RAG_Skip_Retrieval_Example.java">_06_Advanced_RAG_Skip_Retrieval_Example.java</a> | ||
*/ | ||
class VetQueryRouter implements QueryRouter { | ||
|
||
private static final Logger LOGGER = LoggerFactory.getLogger(VetQueryRouter.class); | ||
|
||
private static final PromptTemplate PROMPT_TEMPLATE = PromptTemplate.from(""" | ||
Is the following query related to one or more veterinarians of the pet clinic? | ||
Answer only 'yes' or 'no'. | ||
Query: {{it}} | ||
"""); | ||
|
||
private final ContentRetriever vetContentRetriever; | ||
|
||
private final ChatLanguageModel chatLanguageModel; | ||
|
||
public VetQueryRouter(ChatLanguageModel chatLanguageModel, ContentRetriever vetContentRetriever) { | ||
this.chatLanguageModel = chatLanguageModel; | ||
this.vetContentRetriever = vetContentRetriever; | ||
} | ||
|
||
@Override | ||
public Collection<ContentRetriever> route(Query query) { | ||
Prompt prompt = PROMPT_TEMPLATE.apply(query.text()); | ||
|
||
AiMessage aiMessage = chatLanguageModel.generate(prompt.toUserMessage()).content(); | ||
LOGGER.debug("LLM decided: {}", aiMessage.text()); | ||
|
||
if (aiMessage.text().toLowerCase().contains("yes")) { | ||
return singletonList(vetContentRetriever); | ||
} | ||
return emptyList(); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
## Override some properties of the application.properties for unit tests | ||
|
||
# Azure OpenAI | ||
# These parameters only apply when using the langchain4j-azure-open-ai-spring-boot-starter dependency | ||
langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY | ||
langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT | ||
langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY | ||
langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT | ||
|
||
# OpenAI | ||
# These parameters only apply when using the langchain4j-open-ai-spring-boot-starter dependency | ||
langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY | ||
langchain4j.open-ai.chat-model.api-key=FAKE_KEY |