Skip to content

Commit

Permalink
Customize chat history
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcheng1982 committed May 11, 2024
1 parent 5381bc0 commit dfab459
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.github.llmagentbuilder.core.config

import io.github.llmagentbuilder.core.Planner
import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore
import io.github.llmagentbuilder.core.planner.ChatHistoryCustomizer
import io.github.llmagentbuilder.core.planner.react.ReActPlannerFactory
import io.github.llmagentbuilder.core.planner.reactjson.ReActJsonPlannerFactory
import io.github.llmagentbuilder.core.planner.simple.SimplePlannerFactory
Expand Down Expand Up @@ -61,6 +62,7 @@ data class ToolsConfig(

data class MemoryConfig(
val chatMemoryStore: ChatMemoryStore? = null,
val chatHistoryCustomizer: ChatHistoryCustomizer? = null,
)

data class ObservationConfig(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.github.llmagentbuilder.core.planner

import org.springframework.ai.chat.messages.Message
import org.springframework.ai.chat.messages.UserMessage
import java.util.function.BiFunction

/**
* Customize chat history before sending to LLM
*/
interface ChatHistoryCustomizer {
fun customize(messages: List<Message>): List<Message>

companion object {
val DEFAULT = object : ChatHistoryCustomizer {
override fun customize(messages: List<Message>): List<Message> {
return messages
}
}
}
}

/**
* Patch content of last [UserMessage] in chat history
*/
open class PatchLastUserMessageChatHistoryCustomizer(
private val messageContentCustomizer: BiFunction<List<Message>, String, String>
) :
ChatHistoryCustomizer {
override fun customize(messages: List<Message>): List<Message> {
val lastUserMessage = messages.lastOrNull {
it is UserMessage
}
lastUserMessage?.let {
val index = messages.lastIndexOf(it)
val updatedMessages = messages.toMutableList()
updatedMessages[index] = UserMessage(
messageContentCustomizer.apply(messages, it.content)
)
return updatedMessages
}
return messages
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import io.github.llmagentbuilder.core.Planner
import io.github.llmagentbuilder.core.chatmemory.ChatMemory
import io.github.llmagentbuilder.core.chatmemory.ChatMemoryProvider
import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore
import io.github.llmagentbuilder.core.chatmemory.MessageWindowChatMemory
import io.github.llmagentbuilder.core.config.AgentConfig
import io.github.llmagentbuilder.core.executor.ActionPlanningResult
import io.github.llmagentbuilder.core.observation.AgentPlanningObservationContext
Expand All @@ -26,6 +25,31 @@ import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.chat.prompt.PromptTemplate
import java.util.*

interface LLMPlannerChatMemoryProvider {
fun provide(
chatMemoryStore: ChatMemoryStore,
inputs: Map<String, Any>
): ChatMemory?

companion object {
val DEFAULT = object : LLMPlannerChatMemoryProvider {
override fun provide(
chatMemoryStore: ChatMemoryStore,
inputs: Map<String, Any>
): ChatMemory? {
return inputs["memory_id"]?.let { memoryId ->
if (memoryId.toString().trim().isNotBlank())
ChatMemoryProvider.DEFAULT.provideChatMemory(
chatMemoryStore,
memoryId.toString()
)
else null
}
}
}
}
}

open class LLMPlanner(
private val chatClient: ChatClient,
private val chatOptions: ChatOptions,
Expand All @@ -35,11 +59,8 @@ open class LLMPlanner(
private val systemPromptTemplate: PromptTemplate? = null,
private val systemInstruction: String? = null,
private val chatMemoryStore: ChatMemoryStore? = null,
private val chatMemoryProvider: ((ChatMemoryStore, Map<String, Any>) -> ChatMemory?)? = { store, inputs ->
inputs["memory_id"]?.let { memoryId ->
MessageWindowChatMemory(store, memoryId.toString(), 10)
}
},
private val chatMemoryProvider: LLMPlannerChatMemoryProvider? = null,
private val chatHistoryCustomizer: ChatHistoryCustomizer? = null,
private val observationRegistry: ObservationRegistry? = null,
private val meterRegistry: MeterRegistry? = null,
private val stopSequence: List<String>? = null,
Expand Down Expand Up @@ -75,14 +96,16 @@ open class LLMPlanner(
}

val chatMemory = chatMemoryStore?.let { store ->
chatMemoryProvider?.invoke(store, inputs)
chatMemoryProvider?.provide(store, inputs)
}

chatMemory?.let { memory ->
messages.forEach(memory::add)
}
val prompt = Prompt(
chatMemory?.messages() ?: messages,
chatMemory?.messages()
?.let { chatHistoryCustomizer?.customize(it) ?: it }
?: messages,
prepareChatClientOptions(toolNames)
)
val response = chatClient.call(prompt)
Expand Down Expand Up @@ -170,6 +193,7 @@ open class LLMPlanner(
private var systemInstruction: String? = null
private var chatMemoryStore: ChatMemoryStore? = null
private var chatMemoryProvider: ChatMemoryProvider? = null
private var chatHistoryCustomizer: ChatHistoryCustomizer? = null
private var stopSequence: List<String>? = null

fun withChatClient(chatClient: ChatClient): Builder {
Expand Down Expand Up @@ -229,6 +253,11 @@ open class LLMPlanner(
return this
}

fun withChatHistoryCustomizer(chatHistoryCustomizer: ChatHistoryCustomizer?): Builder {
this.chatHistoryCustomizer = chatHistoryCustomizer
return this
}

fun withStopSequence(stopSequence: List<String>?): Builder {
this.stopSequence = stopSequence
return this
Expand All @@ -251,14 +280,8 @@ open class LLMPlanner(
systemPromptTemplate,
systemInstruction,
chatMemoryStore,
{ store, inputs ->
inputs["memory_id"]?.let { memoryId ->
ChatMemoryProvider.DEFAULT.provideChatMemory(
store,
memoryId.toString()
)
}
},
LLMPlannerChatMemoryProvider.DEFAULT,
chatHistoryCustomizer,
observationRegistry,
meterRegistry,
stopSequence,
Expand All @@ -274,14 +297,15 @@ abstract class LLMPlannerFactory {
val (chatClient, chatOptions) = agentConfig.llmConfig
val (_, systemInstruction) = agentConfig.plannerConfig()
val (agentToolsProvider) = agentConfig.toolsConfig()
val (chatMemoryStore) = agentConfig.memoryConfig()
val (chatMemoryStore, chatHistoryCustomizer) = agentConfig.memoryConfig()
val (observationRegistry, meterRegistry) = agentConfig.observationConfig()
return create(
chatClient,
chatOptions,
agentToolsProvider,
systemInstruction,
chatMemoryStore,
chatHistoryCustomizer,
observationRegistry,
meterRegistry
)
Expand All @@ -293,6 +317,7 @@ abstract class LLMPlannerFactory {
agentToolsProvider: AgentToolsProvider? = null,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
chatHistoryCustomizer: ChatHistoryCustomizer? = null,
observationRegistry: ObservationRegistry? = null,
meterRegistry: MeterRegistry? = null,
): LLMPlanner {
Expand All @@ -302,6 +327,7 @@ abstract class LLMPlannerFactory {
.withAgentToolsProvider(agentToolsProvider)
.withSystemInstruction(systemInstruction)
.withChatMemoryStore(chatMemoryStore)
.withChatHistoryCustomizer(chatHistoryCustomizer)
.withObservationRegistry(observationRegistry)
.withMeterRegistry(meterRegistry)
.build()
Expand Down
2 changes: 1 addition & 1 deletion llm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
<packaging>pom</packaging>

<properties>
<dashscope-client.version>1.1.4</dashscope-client.version>
<dashscope-client.version>1.1.5</dashscope-client.version>
</properties>

<dependencyManagement>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.github.llmagentbuilder.core.Planner;
import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore;
import io.github.llmagentbuilder.core.chatmemory.InMemoryChatMemoryStore;
import io.github.llmagentbuilder.core.planner.ChatHistoryCustomizer;
import io.github.llmagentbuilder.core.planner.reactjson.ReActJsonPlannerFactory;
import io.github.llmagentbuilder.core.planner.simple.SimplePlannerFactory;
import io.github.llmagentbuilder.core.tool.AgentToolFunctionCallbackContext;
Expand All @@ -20,6 +21,7 @@
import io.micrometer.observation.ObservationRegistry;
import java.util.List;
import java.util.Optional;
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.ChatClient;
Expand All @@ -39,7 +41,9 @@
import org.springframework.context.annotation.Primary;

@AutoConfiguration(before = WebMvcAutoConfiguration.class, after = {
OllamaAutoConfiguration.class, OpenAiAutoConfiguration.class,
OllamaAutoConfiguration.class,
OpenAiAutoConfiguration.class,
MistralAiAutoConfiguration.class,
DashscopeAutoConfiguration.class,
ObservationAutoConfiguration.class})
@ConditionalOnProperty(prefix = ChatAgentProperties.CONFIG_PREFIX, name = "enabled", matchIfMissing = true)
Expand Down Expand Up @@ -75,6 +79,7 @@ public Planner reActJsonPlanner(ChatClient chatClient,
ChatOptions chatOptions,
Optional<ChatMemoryStore> chatMemoryStore,
AgentToolsProvider agentToolsProvider,
Optional<ChatHistoryCustomizer> chatHistoryCustomizer,
Optional<ObservationRegistry> observationRegistry,
Optional<MeterRegistry> meterRegistry) {
return ReActJsonPlannerFactory.INSTANCE.create(
Expand All @@ -83,6 +88,7 @@ public Planner reActJsonPlanner(ChatClient chatClient,
agentToolsProvider,
properties.getPlanner().getSystemInstructions(),
chatMemoryStore.orElse(null),
chatHistoryCustomizer.orElse(null),
properties.tracingEnabled() ? observationRegistry.orElse(null)
: null,
properties.metricsEnabled() ? meterRegistry.orElse(null)
Expand All @@ -98,6 +104,7 @@ public Planner simplePlanner(ChatClient chatClient,
ChatOptions chatOptions,
Optional<ChatMemoryStore> chatMemoryStore,
AgentToolsProvider agentToolsProvider,
Optional<ChatHistoryCustomizer> chatHistoryCustomizer,
Optional<ObservationRegistry> observationRegistry,
Optional<MeterRegistry> meterRegistry) {
return SimplePlannerFactory.INSTANCE.create(
Expand All @@ -106,6 +113,7 @@ public Planner simplePlanner(ChatClient chatClient,
agentToolsProvider,
properties.getPlanner().getSystemInstructions(),
chatMemoryStore.orElse(null),
chatHistoryCustomizer.orElse(null),
properties.tracingEnabled() ? observationRegistry.orElse(null)
: null,
properties.metricsEnabled() ? meterRegistry.orElse(null)
Expand Down

0 comments on commit dfab459

Please sign in to comment.