diff --git a/core/cat/agents/base_agent.py b/core/cat/agents/base_agent.py index d9116d508..47e93de79 100644 --- a/core/cat/agents/base_agent.py +++ b/core/cat/agents/base_agent.py @@ -1,11 +1,11 @@ -from typing import List +from typing import List, Any from abc import ABC, abstractmethod - from cat.utils import BaseModelDict + class AgentOutput(BaseModelDict): - output: str | None = None + output: Any | None = None intermediate_steps: List = [] return_direct: bool = False @@ -14,4 +14,4 @@ class BaseAgent(ABC): @abstractmethod def execute(*args, **kwargs) -> AgentOutput: - pass \ No newline at end of file + pass diff --git a/core/cat/agents/memory_agent.py b/core/cat/agents/memory_agent.py index 22c34f954..6f5ae6f59 100644 --- a/core/cat/agents/memory_agent.py +++ b/core/cat/agents/memory_agent.py @@ -28,12 +28,17 @@ def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput: ] ) + output_parser = StrOutputParser() + output_parser = stray.mad_hatter.execute_hook( + "agent_output_parser", output_parser, cat=stray + ) + chain = ( prompt | RunnableLambda(lambda x: utils.langchain_log_prompt(x, "MAIN PROMPT")) | stray._llm | RunnableLambda(lambda x: utils.langchain_log_output(x, "MAIN PROMPT OUTPUT")) - | StrOutputParser() + | output_parser ) output = chain.invoke( diff --git a/core/cat/convo/messages.py b/core/cat/convo/messages.py index 2c74c339b..4d56e58b2 100644 --- a/core/cat/convo/messages.py +++ b/core/cat/convo/messages.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import List, Literal, Any from cat.utils import BaseModelDict from langchain_core.messages import BaseMessage, AIMessage, HumanMessage from enum import Enum @@ -60,7 +60,7 @@ class CatMessage(BaseModelDict): user_id (str): user id """ - content: str + content: Any user_id: str type: str = "chat" why: MessageWhy | None = None diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index e841c6d86..f11a76918 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -448,7 +448,7 @@ def __call__(self, message_dict): # prepare final cat message final_output = CatMessage( - user_id=self.user_id, content=str(agent_output.output), why=why + user_id=self.user_id, content=agent_output.output, why=why ) # run message through plugins diff --git a/core/cat/mad_hatter/core_plugin/hooks/agent.py b/core/cat/mad_hatter/core_plugin/hooks/agent.py index 55fa3661f..db3295b5c 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/agent.py +++ b/core/cat/mad_hatter/core_plugin/hooks/agent.py @@ -86,3 +86,42 @@ def agent_allowed_tools(allowed_tools: List[str], cat) -> List[str]: """ return allowed_tools + + +@hook(priority=0) +def agent_output_parser(output_parser, cat): + """Hook the output parser. + + Allows to edit the output parser of the *Agent*. + + Parameters + --------- + output_parser : LangchainOutputParser + Output parser of the *Agent*. + cat : StrayCat + StrayCat instance. + + Returns + ------- + output_parser : LangchainOutputParser + Output parser of the *Agent*. + + Examples + -------- + + Example 1: use Langchain's PydanticOutputParser to return a structured response + + ```python + + from pydantic import BaseModel + from langchain_core.output_parsers.pydantic import PydanticOutputParser + + class CheshireCatAnswer(BaseModel): + output: str + level_of_madness: int + + output_parser = PydanticOutputParser(pydantic_object=CheshireCatAnswer) + + ``` + """ + return output_parser