Skip to content

Commit

Permalink
configure ChatOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcheng1982 committed May 6, 2024
1 parent 9eda9d3 commit daa0e29
Show file tree
Hide file tree
Showing 28 changed files with 253 additions and 75 deletions.
11 changes: 0 additions & 11 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,6 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
</dependency>
<dependency>
<groupId>io.github.alexcheng1982</groupId>
<artifactId>spring-ai-dashscope-client</artifactId>
<version>0.8.0</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-kotlin</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package io.github.llmagentbuilder.core.config

import io.github.llmagentbuilder.core.Planner
import io.github.llmagentbuilder.core.chatmemory.ChatMemoryStore
import io.github.llmagentbuilder.core.planner.planner.react.ReActPlannerFactory
import io.github.llmagentbuilder.core.planner.planner.reactjson.ReActJsonPlannerFactory
import io.github.llmagentbuilder.core.planner.planner.simple.SimplePlannerFactory
import io.github.llmagentbuilder.core.planner.planner.structuredchat.StructuredChatPlannerFactory
import io.github.llmagentbuilder.core.planner.react.ReActPlannerFactory
import io.github.llmagentbuilder.core.planner.reactjson.ReActJsonPlannerFactory
import io.github.llmagentbuilder.core.planner.simple.SimplePlannerFactory
import io.github.llmagentbuilder.core.planner.structuredchat.StructuredChatPlannerFactory
import io.github.llmagentbuilder.core.tool.AgentToolsProvider
import io.github.llmagentbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.micrometer.core.instrument.MeterRegistry
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.prompt.ChatOptions

enum class PlannerType {
ReAct {
Expand Down Expand Up @@ -46,6 +47,7 @@ data class MetadataConfig(

data class LLMConfig(
val chatClient: ChatClient,
val chatOptions: ChatOptions,
)

data class PlannerConfig(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package io.github.llmagentbuilder.core.executor

import io.github.llmagentbuilder.core.*
import io.github.llmagentbuilder.core.observation.AgentExecutionObservationContext
import io.github.llmagentbuilder.core.observation.AgentExecutionObservationDocumentation
import io.github.llmagentbuilder.core.observation.DefaultAgentExecutionObservationConvention
import io.github.llmagentbuilder.core.planner.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.planner.OutputParserExceptionHandler
import io.github.llmagentbuilder.core.planner.planner.ParseResult
import io.github.llmagentbuilder.core.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.OutputParserExceptionHandler
import io.github.llmagentbuilder.core.planner.ParseResult
import io.github.llmagentbuilder.core.tool.ExceptionTool
import io.github.llmagentbuilder.core.tool.InvalidTool
import io.github.llmagentbuilder.core.tool.InvalidToolInput
import io.github.llmagentbuilder.core.*
import io.micrometer.observation.ObservationRegistry
import org.slf4j.LoggerFactory
import org.springframework.ai.model.function.FunctionCallback
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.github.llmagentbuilder.core.planner

import org.springframework.ai.chat.prompt.ChatOptions

interface ChatOptionsConfigurer {

data class ChatOptionsConfig(
val toolNames: Set<String>? = null,
val stopSequence: List<String>? = null,
)

fun supports(chatOptions: ChatOptions): Boolean
fun configure(
chatOptions: ChatOptions,
config: ChatOptionsConfig
): ChatOptions
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.github.llmagentbuilder.core.planner.planner
package io.github.llmagentbuilder.core.planner

import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.github.llmagentbuilder.core.planner.planner
package io.github.llmagentbuilder.core.planner

import io.github.llmagentbuilder.core.AgentFinish
import io.github.llmagentbuilder.core.IntermediateAgentStep
Expand All @@ -13,23 +13,22 @@ import io.github.llmagentbuilder.core.observation.AgentPlanningObservationContex
import io.github.llmagentbuilder.core.observation.AgentPlanningObservationDocumentation
import io.github.llmagentbuilder.core.observation.DefaultAgentPlanningObservationConvention
import io.github.llmagentbuilder.core.observation.InstrumentedChatClient
import io.github.llmagentbuilder.core.planner.planner.simple.SimpleOutputParser
import io.github.llmagentbuilder.core.planner.simple.SimpleOutputParser
import io.github.llmagentbuilder.core.tool.AgentTool
import io.github.llmagentbuilder.core.tool.AgentToolsProvider
import io.github.llmagentbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.github.alexcheng1982.springai.dashscope.DashscopeChatClient
import io.github.alexcheng1982.springai.dashscope.DashscopeChatOptions
import io.github.alexcheng1982.springai.dashscope.api.DashscopeModelName
import io.micrometer.core.instrument.MeterRegistry
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.messages.SystemMessage
import org.springframework.ai.chat.prompt.ChatOptions
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.chat.prompt.PromptTemplate
import java.util.*

open class LLMPlanner(
private val chatClient: ChatClient,
private val chatOptions: ChatOptions,
private val toolsProvider: AgentToolsProvider,
private val outputParser: OutputParser,
private val userPromptTemplate: PromptTemplate,
Expand All @@ -43,7 +42,7 @@ open class LLMPlanner(
},
private val observationRegistry: ObservationRegistry? = null,
private val meterRegistry: MeterRegistry? = null,
private val stopSequence: List<String>? = listOf("\\nObservation")
private val stopSequence: List<String>? = null,
) : Planner {

override fun plan(
Expand Down Expand Up @@ -84,7 +83,7 @@ open class LLMPlanner(
}
val prompt = Prompt(
chatMemory?.messages() ?: messages,
prepareChatClientOptions(chatClient, toolNames)
prepareChatClientOptions(toolNames)
)
val response = chatClient.call(prompt)
val text = response.result?.output?.content?.trim() ?: ""
Expand Down Expand Up @@ -143,20 +142,14 @@ open class LLMPlanner(
}

private fun prepareChatClientOptions(
chatClient: ChatClient,
toolNames: Set<String>
): ChatOptions? {
val client =
if (chatClient is InstrumentedChatClient) chatClient.chatClient else chatClient
if (client is DashscopeChatClient) {
return DashscopeChatOptions.builder()
.withModel(DashscopeModelName.QWEN_MAX)
.withTemperature(0.2f)
.withFunctions(toolNames)
.withStops(stopSequence)
.build()
}
return null
): ChatOptions {
return ServiceLoader.load(ChatOptionsConfigurer::class.java)
.firstOrNull { it.supports(chatOptions) }?.configure(
chatOptions, ChatOptionsConfigurer.ChatOptionsConfig(
toolNames, stopSequence
)
) ?: chatOptions
}

override fun toString(): String {
Expand All @@ -165,6 +158,7 @@ open class LLMPlanner(

class Builder {
private lateinit var chatClient: ChatClient
private lateinit var chatOptions: ChatOptions
private var toolsProvider: AgentToolsProvider? = null
private var outputParser: OutputParser = SimpleOutputParser.INSTANCE
private var observationRegistry: ObservationRegistry? = null
Expand All @@ -176,12 +170,18 @@ open class LLMPlanner(
private var systemInstruction: String? = null
private var chatMemoryStore: ChatMemoryStore? = null
private var chatMemoryProvider: ChatMemoryProvider? = null
private var stopSequence: List<String>? = null

fun withChatClient(chatClient: ChatClient): Builder {
this.chatClient = chatClient
return this
}

fun withChatOptions(chatOptions: ChatOptions): Builder {
this.chatOptions = chatOptions
return this
}

fun withAgentToolsProvider(toolsProvider: AgentToolsProvider?): Builder {
this.toolsProvider = toolsProvider
return this
Expand Down Expand Up @@ -229,6 +229,11 @@ open class LLMPlanner(
return this
}

fun withStopSequence(stopSequence: List<String>?): Builder {
this.stopSequence = stopSequence
return this
}

fun build(): LLMPlanner {
if (!::chatClient.isInitialized) {
throw IllegalArgumentException("ChatClient is required")
Expand All @@ -239,6 +244,7 @@ open class LLMPlanner(
)
return LLMPlanner(
chatClient,
chatOptions,
toolsProvider ?: AutoDiscoveredAgentToolsProvider,
outputParser,
userPromptTemplate,
Expand All @@ -255,6 +261,7 @@ open class LLMPlanner(
},
observationRegistry,
meterRegistry,
stopSequence,
)
}
}
Expand All @@ -264,13 +271,14 @@ abstract class LLMPlannerFactory {
abstract fun defaultBuilder(): LLMPlanner.Builder

fun create(agentConfig: AgentConfig): LLMPlanner {
val (chatClient) = agentConfig.llmConfig
val (chatClient, chatOptions) = agentConfig.llmConfig
val (_, systemInstruction) = agentConfig.plannerConfig()
val (agentToolsProvider) = agentConfig.toolsConfig()
val (chatMemoryStore) = agentConfig.memoryConfig()
val (observationRegistry, meterRegistry) = agentConfig.observationConfig()
return create(
chatClient,
chatOptions,
agentToolsProvider,
systemInstruction,
chatMemoryStore,
Expand All @@ -281,6 +289,7 @@ abstract class LLMPlannerFactory {

fun create(
chatClient: ChatClient,
chatOptions: ChatOptions,
agentToolsProvider: AgentToolsProvider? = null,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
Expand All @@ -289,6 +298,7 @@ abstract class LLMPlannerFactory {
): LLMPlanner {
return defaultBuilder()
.withChatClient(chatClient)
.withChatOptions(chatOptions)
.withAgentToolsProvider(agentToolsProvider)
.withSystemInstruction(systemInstruction)
.withChatMemoryStore(chatMemoryStore)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.github.llmagentbuilder.core.planner.planner
package io.github.llmagentbuilder.core.planner

import io.github.llmagentbuilder.core.AgentAction
import io.github.llmagentbuilder.core.AgentFinish
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package io.github.llmagentbuilder.core.planner.planner.react
package io.github.llmagentbuilder.core.planner.react

import io.github.llmagentbuilder.core.AgentAction
import io.github.llmagentbuilder.core.planner.planner.OutputParser
import io.github.llmagentbuilder.core.planner.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.planner.ParseResult
import io.github.llmagentbuilder.core.planner.OutputParser
import io.github.llmagentbuilder.core.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.ParseResult
import java.util.regex.Pattern

class ReActOutputParser : OutputParser {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.github.llmagentbuilder.core.planner.planner.react
package io.github.llmagentbuilder.core.planner.react

import io.github.llmagentbuilder.core.planner.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.planner.LLMPlannerFactory
import io.github.llmagentbuilder.core.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.LLMPlannerFactory
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource

Expand All @@ -12,5 +12,6 @@ object ReActPlannerFactory : LLMPlannerFactory() {
.withSystemPromptTemplate(PromptTemplate(ClassPathResource("prompts/react/system.st")))
.withOutputParser(ReActOutputParser.INSTANCE)
.withSystemInstruction("Answer the following questions as best you can.")
.withStopSequence(listOf("\\nObservation"))
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.github.llmagentbuilder.core.planner.planner.reactjson
package io.github.llmagentbuilder.core.planner.reactjson

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.KotlinModule
import io.github.llmagentbuilder.core.AgentAction
import io.github.llmagentbuilder.core.AgentFinish
import io.github.llmagentbuilder.core.planner.planner.JsonParser
import io.github.llmagentbuilder.core.planner.planner.OutputParser
import io.github.llmagentbuilder.core.planner.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.planner.ParseResult
import io.github.llmagentbuilder.core.planner.JsonParser
import io.github.llmagentbuilder.core.planner.OutputParser
import io.github.llmagentbuilder.core.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.ParseResult

class ReActJsonOutputParser : OutputParser {
private val finalAnswerAction = "Final Answer:"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.github.llmagentbuilder.core.planner.planner.reactjson
package io.github.llmagentbuilder.core.planner.reactjson

import io.github.llmagentbuilder.core.planner.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.planner.LLMPlannerFactory
import io.github.llmagentbuilder.core.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.LLMPlannerFactory
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource

Expand All @@ -12,5 +12,6 @@ object ReActJsonPlannerFactory : LLMPlannerFactory() {
.withSystemPromptTemplate(PromptTemplate(ClassPathResource("prompts/react-json/system.st")))
.withOutputParser(ReActJsonOutputParser.INSTANCE)
.withSystemInstruction("Answer the following questions as best you can.")
.withStopSequence(listOf("\\nObservation"))
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.github.llmagentbuilder.core.planner.planner.simple
package io.github.llmagentbuilder.core.planner.simple

import io.github.llmagentbuilder.core.AgentFinish
import io.github.llmagentbuilder.core.planner.planner.OutputParser
import io.github.llmagentbuilder.core.planner.planner.ParseResult
import io.github.llmagentbuilder.core.planner.OutputParser
import io.github.llmagentbuilder.core.planner.ParseResult

/**
* Output from LLM is used as a return value directly, no further actions will be taken.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.github.llmagentbuilder.core.planner.planner.simple
package io.github.llmagentbuilder.core.planner.simple

import io.github.llmagentbuilder.core.planner.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.planner.LLMPlannerFactory
import io.github.llmagentbuilder.core.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.LLMPlannerFactory
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.github.llmagentbuilder.core.planner.planner.structuredchat;
package io.github.llmagentbuilder.core.planner.structuredchat;

import com.fasterxml.jackson.annotation.JsonProperty;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package io.github.llmagentbuilder.core.planner.planner.structuredchat
package io.github.llmagentbuilder.core.planner.structuredchat

import com.fasterxml.jackson.databind.ObjectMapper
import io.github.llmagentbuilder.core.AgentAction
import io.github.llmagentbuilder.core.planner.planner.OutputParser
import io.github.llmagentbuilder.core.planner.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.planner.ParseResult
import io.github.llmagentbuilder.core.planner.OutputParser
import io.github.llmagentbuilder.core.planner.OutputParserException
import io.github.llmagentbuilder.core.planner.ParseResult
import java.util.regex.Pattern

class StructuredChatOutputParser : OutputParser {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.github.llmagentbuilder.core.planner.planner.structuredchat
package io.github.llmagentbuilder.core.planner.structuredchat

import io.github.llmagentbuilder.core.planner.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.planner.LLMPlannerFactory
import io.github.llmagentbuilder.core.planner.LLMPlanner
import io.github.llmagentbuilder.core.planner.LLMPlannerFactory
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.github.alexcheng1982.llmagentbuilder.core.planner

import io.github.llmagentbuilder.core.planner.planner.JsonParser
import io.github.llmagentbuilder.core.planner.JsonParser
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Test
Expand Down
Loading

0 comments on commit daa0e29

Please sign in to comment.