From dfab459c5d657051693786c08be4ffcf791f7637 Mon Sep 17 00:00:00 2001 From: Fu Cheng Date: Sat, 11 May 2024 12:02:20 +0800 Subject: [PATCH] Customize chat history --- .../core/config/AgentConfig.kt | 2 + .../core/planner/ChatHistoryCustomizer.kt | 44 ++++++++++++++ .../core/planner/LLMPlanner.kt | 60 +++++++++++++------ llm/pom.xml | 2 +- .../chatagent/ChatAgentAutoConfiguration.java | 10 +++- 5 files changed, 99 insertions(+), 19 deletions(-) create mode 100644 core/src/main/kotlin/io/github/llmagentbuilder/core/planner/ChatHistoryCustomizer.kt diff --git a/core/src/main/kotlin/io/github/llmagentbuilder/core/config/AgentConfig.kt b/core/src/main/kotlin/io/github/llmagentbuilder/core/config/AgentConfig.kt index edd37bb..62bbee2 100644 --- a/core/src/main/kotlin/io/github/llmagentbuilder/core/config/AgentConfig.kt +++ b/core/src/main/kotlin/io/github/llmagentbuilder/core/config/AgentConfig.kt @@ -2,6 +2,7 @@ package io.github.llmagentbuilder.core.config import io.github.llmagentbuilder.core.Planner import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore +import io.github.llmagentbuilder.core.planner.ChatHistoryCustomizer import io.github.llmagentbuilder.core.planner.react.ReActPlannerFactory import io.github.llmagentbuilder.core.planner.reactjson.ReActJsonPlannerFactory import io.github.llmagentbuilder.core.planner.simple.SimplePlannerFactory @@ -61,6 +62,7 @@ data class ToolsConfig( data class MemoryConfig( val chatMemoryStore: ChatMemoryStore? = null, + val chatHistoryCustomizer: ChatHistoryCustomizer? = null, ) data class ObservationConfig( diff --git a/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/ChatHistoryCustomizer.kt b/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/ChatHistoryCustomizer.kt new file mode 100644 index 0000000..62e4a24 --- /dev/null +++ b/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/ChatHistoryCustomizer.kt @@ -0,0 +1,44 @@ +package io.github.llmagentbuilder.core.planner + +import org.springframework.ai.chat.messages.Message +import org.springframework.ai.chat.messages.UserMessage +import java.util.function.BiFunction + +/** + * Customize chat history before sending to LLM + */ +interface ChatHistoryCustomizer { + fun customize(messages: List): List + + companion object { + val DEFAULT = object : ChatHistoryCustomizer { + override fun customize(messages: List): List { + return messages + } + } + } +} + +/** + * Patch content of last [UserMessage] in chat history + */ +open class PatchLastUserMessageChatHistoryCustomizer( + private val messageContentCustomizer: BiFunction, String, String> +) : + ChatHistoryCustomizer { + override fun customize(messages: List): List { + val lastUserMessage = messages.lastOrNull { + it is UserMessage + } + lastUserMessage?.let { + val index = messages.lastIndexOf(it) + val updatedMessages = messages.toMutableList() + updatedMessages[index] = UserMessage( + messageContentCustomizer.apply(messages, it.content) + ) + return updatedMessages + } + return messages + } + +} \ No newline at end of file diff --git a/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/LLMPlanner.kt b/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/LLMPlanner.kt index 1688ecd..7215de9 100644 --- a/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/LLMPlanner.kt +++ b/core/src/main/kotlin/io/github/llmagentbuilder/core/planner/LLMPlanner.kt @@ -6,7 +6,6 @@ import io.github.llmagentbuilder.core.Planner import io.github.llmagentbuilder.core.chatmemory.ChatMemory import io.github.llmagentbuilder.core.chatmemory.ChatMemoryProvider import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore -import io.github.llmagentbuilder.core.chatmemory.MessageWindowChatMemory import io.github.llmagentbuilder.core.config.AgentConfig import io.github.llmagentbuilder.core.executor.ActionPlanningResult import io.github.llmagentbuilder.core.observation.AgentPlanningObservationContext @@ -26,6 +25,31 @@ import org.springframework.ai.chat.prompt.Prompt import org.springframework.ai.chat.prompt.PromptTemplate import java.util.* +interface LLMPlannerChatMemoryProvider { + fun provide( + chatMemoryStore: ChatMemoryStore, + inputs: Map + ): ChatMemory? + + companion object { + val DEFAULT = object : LLMPlannerChatMemoryProvider { + override fun provide( + chatMemoryStore: ChatMemoryStore, + inputs: Map + ): ChatMemory? { + return inputs["memory_id"]?.let { memoryId -> + if (memoryId.toString().trim().isNotBlank()) + ChatMemoryProvider.DEFAULT.provideChatMemory( + chatMemoryStore, + memoryId.toString() + ) + else null + } + } + } + } +} + open class LLMPlanner( private val chatClient: ChatClient, private val chatOptions: ChatOptions, @@ -35,11 +59,8 @@ open class LLMPlanner( private val systemPromptTemplate: PromptTemplate? = null, private val systemInstruction: String? = null, private val chatMemoryStore: ChatMemoryStore? = null, - private val chatMemoryProvider: ((ChatMemoryStore, Map) -> ChatMemory?)? = { store, inputs -> - inputs["memory_id"]?.let { memoryId -> - MessageWindowChatMemory(store, memoryId.toString(), 10) - } - }, + private val chatMemoryProvider: LLMPlannerChatMemoryProvider? = null, + private val chatHistoryCustomizer: ChatHistoryCustomizer? = null, private val observationRegistry: ObservationRegistry? = null, private val meterRegistry: MeterRegistry? = null, private val stopSequence: List? = null, @@ -75,14 +96,16 @@ open class LLMPlanner( } val chatMemory = chatMemoryStore?.let { store -> - chatMemoryProvider?.invoke(store, inputs) + chatMemoryProvider?.provide(store, inputs) } chatMemory?.let { memory -> messages.forEach(memory::add) } val prompt = Prompt( - chatMemory?.messages() ?: messages, + chatMemory?.messages() + ?.let { chatHistoryCustomizer?.customize(it) ?: it } + ?: messages, prepareChatClientOptions(toolNames) ) val response = chatClient.call(prompt) @@ -170,6 +193,7 @@ open class LLMPlanner( private var systemInstruction: String? = null private var chatMemoryStore: ChatMemoryStore? = null private var chatMemoryProvider: ChatMemoryProvider? = null + private var chatHistoryCustomizer: ChatHistoryCustomizer? = null private var stopSequence: List? = null fun withChatClient(chatClient: ChatClient): Builder { @@ -229,6 +253,11 @@ open class LLMPlanner( return this } + fun withChatHistoryCustomizer(chatHistoryCustomizer: ChatHistoryCustomizer?): Builder { + this.chatHistoryCustomizer = chatHistoryCustomizer + return this + } + fun withStopSequence(stopSequence: List?): Builder { this.stopSequence = stopSequence return this @@ -251,14 +280,8 @@ open class LLMPlanner( systemPromptTemplate, systemInstruction, chatMemoryStore, - { store, inputs -> - inputs["memory_id"]?.let { memoryId -> - ChatMemoryProvider.DEFAULT.provideChatMemory( - store, - memoryId.toString() - ) - } - }, + LLMPlannerChatMemoryProvider.DEFAULT, + chatHistoryCustomizer, observationRegistry, meterRegistry, stopSequence, @@ -274,7 +297,7 @@ abstract class LLMPlannerFactory { val (chatClient, chatOptions) = agentConfig.llmConfig val (_, systemInstruction) = agentConfig.plannerConfig() val (agentToolsProvider) = agentConfig.toolsConfig() - val (chatMemoryStore) = agentConfig.memoryConfig() + val (chatMemoryStore, chatHistoryCustomizer) = agentConfig.memoryConfig() val (observationRegistry, meterRegistry) = agentConfig.observationConfig() return create( chatClient, @@ -282,6 +305,7 @@ abstract class LLMPlannerFactory { agentToolsProvider, systemInstruction, chatMemoryStore, + chatHistoryCustomizer, observationRegistry, meterRegistry ) @@ -293,6 +317,7 @@ abstract class LLMPlannerFactory { agentToolsProvider: AgentToolsProvider? = null, systemInstruction: String? = null, chatMemoryStore: ChatMemoryStore? = null, + chatHistoryCustomizer: ChatHistoryCustomizer? = null, observationRegistry: ObservationRegistry? = null, meterRegistry: MeterRegistry? = null, ): LLMPlanner { @@ -302,6 +327,7 @@ abstract class LLMPlannerFactory { .withAgentToolsProvider(agentToolsProvider) .withSystemInstruction(systemInstruction) .withChatMemoryStore(chatMemoryStore) + .withChatHistoryCustomizer(chatHistoryCustomizer) .withObservationRegistry(observationRegistry) .withMeterRegistry(meterRegistry) .build() diff --git a/llm/pom.xml b/llm/pom.xml index ba2e355..126bc39 100644 --- a/llm/pom.xml +++ b/llm/pom.xml @@ -20,7 +20,7 @@ pom - 1.1.4 + 1.1.5 diff --git a/spring/spring-boot-autoconfigure/src/main/java/io/github/llmagentbuilder/spring/autoconfigure/chatagent/ChatAgentAutoConfiguration.java b/spring/spring-boot-autoconfigure/src/main/java/io/github/llmagentbuilder/spring/autoconfigure/chatagent/ChatAgentAutoConfiguration.java index f32cec1..b5f6eee 100644 --- a/spring/spring-boot-autoconfigure/src/main/java/io/github/llmagentbuilder/spring/autoconfigure/chatagent/ChatAgentAutoConfiguration.java +++ b/spring/spring-boot-autoconfigure/src/main/java/io/github/llmagentbuilder/spring/autoconfigure/chatagent/ChatAgentAutoConfiguration.java @@ -7,6 +7,7 @@ import io.github.llmagentbuilder.core.Planner; import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore; import io.github.llmagentbuilder.core.chatmemory.InMemoryChatMemoryStore; +import io.github.llmagentbuilder.core.planner.ChatHistoryCustomizer; import io.github.llmagentbuilder.core.planner.reactjson.ReActJsonPlannerFactory; import io.github.llmagentbuilder.core.planner.simple.SimplePlannerFactory; import io.github.llmagentbuilder.core.tool.AgentToolFunctionCallbackContext; @@ -20,6 +21,7 @@ import io.micrometer.observation.ObservationRegistry; import java.util.List; import java.util.Optional; +import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.ChatClient; @@ -39,7 +41,9 @@ import org.springframework.context.annotation.Primary; @AutoConfiguration(before = WebMvcAutoConfiguration.class, after = { - OllamaAutoConfiguration.class, OpenAiAutoConfiguration.class, + OllamaAutoConfiguration.class, + OpenAiAutoConfiguration.class, + MistralAiAutoConfiguration.class, DashscopeAutoConfiguration.class, ObservationAutoConfiguration.class}) @ConditionalOnProperty(prefix = ChatAgentProperties.CONFIG_PREFIX, name = "enabled", matchIfMissing = true) @@ -75,6 +79,7 @@ public Planner reActJsonPlanner(ChatClient chatClient, ChatOptions chatOptions, Optional chatMemoryStore, AgentToolsProvider agentToolsProvider, + Optional chatHistoryCustomizer, Optional observationRegistry, Optional meterRegistry) { return ReActJsonPlannerFactory.INSTANCE.create( @@ -83,6 +88,7 @@ public Planner reActJsonPlanner(ChatClient chatClient, agentToolsProvider, properties.getPlanner().getSystemInstructions(), chatMemoryStore.orElse(null), + chatHistoryCustomizer.orElse(null), properties.tracingEnabled() ? observationRegistry.orElse(null) : null, properties.metricsEnabled() ? meterRegistry.orElse(null) @@ -98,6 +104,7 @@ public Planner simplePlanner(ChatClient chatClient, ChatOptions chatOptions, Optional chatMemoryStore, AgentToolsProvider agentToolsProvider, + Optional chatHistoryCustomizer, Optional observationRegistry, Optional meterRegistry) { return SimplePlannerFactory.INSTANCE.create( @@ -106,6 +113,7 @@ public Planner simplePlanner(ChatClient chatClient, agentToolsProvider, properties.getPlanner().getSystemInstructions(), chatMemoryStore.orElse(null), + chatHistoryCustomizer.orElse(null), properties.tracingEnabled() ? observationRegistry.orElse(null) : null, properties.metricsEnabled() ? meterRegistry.orElse(null)