diff --git a/readme.md b/readme.md index 7b777719..42f3f4b0 100644 --- a/readme.md +++ b/readme.md @@ -7,6 +7,25 @@ A chatbot using **Generative AI** has been added to the famous Spring Petclinic application. This version uses the **[LangChain4j project](https://docs.langchain4j.dev/)** and currently supports **OpenAI** or **Azure's OpenAI** as the **LLM provider**. This is a fork from the **[spring-petclinic-ai](https://github.com/spring-petclinic/spring-petclinic-ai)** based on Spring AI. +This sample demonstrates how to **easily integrate AI/LLM capabilities into a Java application using LangChain4j**. +This can be achieved thanks to: +* A unified **abstraction layer** designed to decouple your code from specific implementations like LLM or embedding providers, enabling easy component swapping. + Only the [application.properties](src/main/resources/application.properties) file references LLM providers such as OpenAI or Azure OpenAI. +* **Memory** offers context to the LLM for both your current and previous conversations. + Refer to the use of the `MessageWindowChatMemory` class in [AssistantConfiguration](src/main/java/org/springframework/samples/petclinic/chat/AssistantConfiguration.java). +* **AI Services** enables declarative definitions of complex AI behaviors through a straightforward Java API. + See the use of the `@AiService` annotation in the [Assistant](src/main/java/org/springframework/samples/petclinic/chat/Assistant.java) interface. +* **System prompts** play a vital role in LLMs as they shape how models interpret and respond to user queries. + Look at the `@SystemMessage` annotation usage in the [Assistant](src/main/java/org/springframework/samples/petclinic/chat/Assistant.java) interface. +* **Streaming** response token-by-token when using the `TokenStream` return type and Spring *Server-Sent Events* support. 
+ Take a look at the [AssistantController](src/main/java/org/springframework/samples/petclinic/chat/AssistantController.java) REST controller. +* **Function calling** or **Tools** allows the LLM to call, when necessary, one or more Java methods. + The [AssistantTool](src/main/java/org/springframework/samples/petclinic/chat/AssistantTool.java) component declares functions using the `@Tool` annotation from LangChain4j. +* **Structured outputs** allow LLM responses to be received in a specified format as Java POJOs. + [AssistantTool](src/main/java/org/springframework/samples/petclinic/chat/AssistantTool.java) uses Java records as the LLM input/output data structure. +* **Retrieval-Augmented Generation** (RAG) enables an LLM to incorporate and respond based on specific data—such as data from the petclinic database—by ingesting and referencing it during interactions. + The [AssistantConfiguration](src/main/java/org/springframework/samples/petclinic/chat/AssistantConfiguration.java) declares the `EmbeddingModel`, `InMemoryEmbeddingStore` and `EmbeddingStoreContentRetriever` beans while the [EmbeddingStoreInit](src/main/java/org/springframework/samples/petclinic/chat/EmbeddingStoreInit.java) class handles vets data ingestion at startup. + Spring Petclinic integrates a Chatbot that allows you to interact with the application in a natural language. Here are **some examples** of what you could ask: 1. Please list the owners that come to the clinic. 
diff --git a/src/main/java/org/springframework/samples/petclinic/chat/Assistant.java b/src/main/java/org/springframework/samples/petclinic/chat/Assistant.java index 8a2dc77e..a1eb1ec1 100644 --- a/src/main/java/org/springframework/samples/petclinic/chat/Assistant.java +++ b/src/main/java/org/springframework/samples/petclinic/chat/Assistant.java @@ -1,12 +1,13 @@ package org.springframework.samples.petclinic.chat; import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.spring.AiService; @AiService interface Assistant { @SystemMessage(fromResource = "/prompts/system.st") - String chat(String userMessage); + TokenStream chat(String userMessage); } diff --git a/src/main/java/org/springframework/samples/petclinic/chat/AssistantController.java b/src/main/java/org/springframework/samples/petclinic/chat/AssistantController.java index b2da589f..63f20979 100644 --- a/src/main/java/org/springframework/samples/petclinic/chat/AssistantController.java +++ b/src/main/java/org/springframework/samples/petclinic/chat/AssistantController.java @@ -1,21 +1,61 @@ package org.springframework.samples.petclinic.chat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; @RestController class AssistantController { + private static final Logger LOGGER = LoggerFactory.getLogger(AssistantController.class); + private final Assistant assistant; + private final ExecutorService nonBlockingService = Executors.newCachedThreadPool(); + AssistantController(Assistant assistant) { this.assistant = assistant; } - @PostMapping("/chat") - public String chat(@RequestBody 
String query) { - return assistant.chat(query); + // Using the POST method due to chat memory capabilities + @PostMapping(value = "/chat") + public SseEmitter chat(@RequestBody String query) { + SseEmitter emitter = new SseEmitter(); + nonBlockingService.execute(() -> assistant.chat(query).onNext(message -> { + try { + sendMessage(emitter, message); + } + catch (IOException e) { + LOGGER.error("Error while writing next token", e); + emitter.completeWithError(e); + } + }).onComplete(token -> emitter.complete()).onError(error -> { + LOGGER.error("Unexpected chat error", error); + try { + sendMessage(emitter, error.getMessage()); + } + catch (IOException e) { + LOGGER.error("Error while writing next token", e); + } + emitter.completeWithError(error); + }).start()); + return emitter; + } + + private static void sendMessage(SseEmitter emitter, String message) throws IOException { + String token = message + // Hack line break problem when using Server Sent Events (SSE) + .replace("\n", "
") + // Escape JSON quotes + .replace("\"", "\\\""); + emitter.send("{\"t\": \"" + token + "\"}"); } } diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 9fa989eb..4a2563a3 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -28,16 +28,16 @@ spring.web.resources.cache.cachecontrol.max-age=12h # Azure OpenAI # These parameters only apply when using the langchain4j-azure-open-ai-spring-boot-starter dependency -langchain4j.azure-open-ai.chat-model.api-key=${AZURE_OPENAI_KEY} -langchain4j.azure-open-ai.chat-model.endpoint=${AZURE_OPENAI_ENDPOINT} -langchain4j.azure-open-ai.chat-model.deployment-name=gpt-4o -langchain4j.azure-open-ai.chat-model.log-requests-and-responses=true +langchain4j.azure-open-ai.streaming-chat-model.api-key=${AZURE_OPENAI_KEY} +langchain4j.azure-open-ai.streaming-chat-model.endpoint=${AZURE_OPENAI_ENDPOINT} +langchain4j.azure-open-ai.streaming-chat-model.deployment-name=gpt-4o +langchain4j.azure-open-ai.streaming-chat-model.log-requests-and-responses=true # OpenAI # These parameters only apply when using the langchain4j-open-ai-spring-boot-starter dependency -langchain4j.open-ai.chat-model.api-key=${OPENAI_API_KEY} -langchain4j.open-ai.chat-model.model-name=gpt-4o -langchain4j.open-ai.chat-model.log-requests=true -langchain4j.open-ai.chat-model.log-responses=true +langchain4j.open-ai.streaming-chat-model.api-key=${OPENAI_API_KEY} +langchain4j.open-ai.streaming-chat-model.model-name=gpt-4o +langchain4j.open-ai.streaming-chat-model.log-requests=true +langchain4j.open-ai.streaming-chat-model.log-responses=true diff --git a/src/main/resources/static/resources/js/chat.js b/src/main/resources/static/resources/js/chat.js index 0b7647f1..8296da29 100644 --- a/src/main/resources/static/resources/js/chat.js +++ b/src/main/resources/static/resources/js/chat.js @@ -1,22 +1,25 @@ - -function appendMessage(message, type) { - const chatMessages = 
document.getElementById('chatbox-messages'); - const messageElement = document.createElement('div'); - messageElement.classList.add('chat-bubble', type); + +function displayMessage(message, elements) { + let {chatMessages, messageElement} = elements; // Convert Markdown to HTML // May interpret bullet syntax like // 1. **Betty Davis** - const htmlContent = marked.parse(message); - messageElement.innerHTML = htmlContent; - - chatMessages.appendChild(messageElement); + messageElement.innerHTML = marked.parse(message); // Scroll to the bottom of the chatbox to show the latest message chatMessages.scrollTop = chatMessages.scrollHeight; } +function prepareMessage(type) { + const chatMessages = document.getElementById('chatbox-messages'); + const messageElement = document.createElement('div'); + messageElement.classList.add('chat-bubble', type); + chatMessages.appendChild(messageElement); + return {chatMessages, messageElement}; +} + function toggleChatbox() { const chatbox = document.getElementById('chatbox'); const chatboxContent = document.getElementById('chatbox-content'); @@ -30,7 +33,7 @@ function toggleChatbox() { } } -function sendMessage() { +async function sendMessage() { const query = document.getElementById('chatbox-input').value; // Only send if there's a message @@ -39,23 +42,59 @@ function sendMessage() { // Clear the input field after sending the message document.getElementById('chatbox-input').value = ''; - // Display user message in the chatbox - appendMessage(query, 'user'); + // Display user message in the chat box + const userElements = prepareMessage("user"); + displayMessage(query, userElements); - // Send the message to the backend - fetch('/chat', { + // We'll start by using fetch to initiate a POST request to our SSE endpoint. + // This endpoint is configured to send multiple messages, with the response header Content-Type: text/event-stream. 
+ let response = await fetch('/chat', { method: 'POST', headers: { + 'Accept': 'text/event-stream', 'Content-Type': 'application/json', + 'Cache-Control': 'no-cache' }, - body: JSON.stringify(query), - }) - .then(response => response.text()) - .then(responseText => { - // Display the response from the server in the chatbox - appendMessage(responseText, 'bot'); - }) - .catch(error => console.error('Error:', error)); + body: JSON.stringify(query) + }); + + if (response.ok) { + await displayBotReply(response); + } else { + const botElements = prepareMessage('bot'); + displayMessage('Unexpected server error', botElements); + } + +} + + +async function displayBotReply(response) { + // Instantiate a reader to process each network request as it arrives from the server. + const reader = response.body?.getReader(); + + // Set up a loop to keep receiving messages until the done signal is triggered. + // Within this loop, update your frontend application with the incoming SSE messages. + const botElements = prepareMessage('bot'); + let fullReply = ""; + while (true) { + const {value, done} = await reader.read(); + const chars = new TextDecoder().decode(value); + if (done) { + break; + } + const dataArray = chars.trim().split("\n\n"); + const jsonObjects = dataArray.map((data) => { + const jsonString = data.includes("data:") ? data.substring("data:".length) : data; + if (jsonString.length === 0) { + return null; + } + return JSON.parse(jsonString); + }).filter(obj => obj !== null); + jsonObjects.forEach((item) => { + fullReply += item.t.replaceAll('
', '\n'); + }); + displayMessage(fullReply, botElements); + } } function handleKeyPress(event) { diff --git a/src/test/java/org/springframework/samples/petclinic/MySqlIntegrationTests.java b/src/test/java/org/springframework/samples/petclinic/MySqlIntegrationTests.java index c8ae66bc..d8796bdd 100644 --- a/src/test/java/org/springframework/samples/petclinic/MySqlIntegrationTests.java +++ b/src/test/java/org/springframework/samples/petclinic/MySqlIntegrationTests.java @@ -38,9 +38,9 @@ import org.testcontainers.junit.jupiter.Testcontainers; @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, - properties = { "langchain4j.open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT" }) + properties = { "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" }) @ActiveProfiles("mysql") @Testcontainers(disabledWithoutDocker = true) @DisabledInNativeImage diff --git a/src/test/java/org/springframework/samples/petclinic/PetClinicIntegrationTests.java b/src/test/java/org/springframework/samples/petclinic/PetClinicIntegrationTests.java index 103078dc..50af8f33 100644 --- a/src/test/java/org/springframework/samples/petclinic/PetClinicIntegrationTests.java +++ b/src/test/java/org/springframework/samples/petclinic/PetClinicIntegrationTests.java @@ -32,9 +32,9 @@ import org.springframework.web.client.RestTemplate; @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, - properties = { "langchain4j.open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT" }) + properties = { "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY", + 
"langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" }) public class PetClinicIntegrationTests { @LocalServerPort diff --git a/src/test/java/org/springframework/samples/petclinic/PostgresIntegrationTests.java b/src/test/java/org/springframework/samples/petclinic/PostgresIntegrationTests.java index 768eaa7d..9494cb8f 100644 --- a/src/test/java/org/springframework/samples/petclinic/PostgresIntegrationTests.java +++ b/src/test/java/org/springframework/samples/petclinic/PostgresIntegrationTests.java @@ -49,9 +49,9 @@ @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = { "spring.docker.compose.skip.in-tests=false", "spring.docker.compose.profiles.active=postgres", - "langchain4j.open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT" }) + "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" }) @ActiveProfiles("postgres") @DisabledInNativeImage public class PostgresIntegrationTests { diff --git a/src/test/java/org/springframework/samples/petclinic/system/CrashControllerIntegrationTests.java b/src/test/java/org/springframework/samples/petclinic/system/CrashControllerIntegrationTests.java index cdf4c1e2..6d016a00 100644 --- a/src/test/java/org/springframework/samples/petclinic/system/CrashControllerIntegrationTests.java +++ b/src/test/java/org/springframework/samples/petclinic/system/CrashControllerIntegrationTests.java @@ -48,9 +48,9 @@ // NOT Waiting https://github.com/spring-projects/spring-boot/issues/5574 @SpringBootTest(webEnvironment = RANDOM_PORT, properties = { "server.error.include-message=ALWAYS", "management.endpoints.enabled-by-default=false", - "langchain4j.open-ai.chat-model.api-key=FAKE_KEY", - "langchain4j.azure-open-ai.chat-model.api-key=FAKE_KEY", - 
"langchain4j.azure-open-ai.chat-model.endpoint=FAKE_ENDPOINT" }) + "langchain4j.open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.api-key=FAKE_KEY", + "langchain4j.azure-open-ai.streaming-chat-model.endpoint=FAKE_ENDPOINT" }) class CrashControllerIntegrationTests { @SpringBootApplication(exclude = { DataSourceAutoConfiguration.class,