Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Agent output parser #984

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions core/cat/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List
from typing import List, Any
from abc import ABC, abstractmethod


from cat.utils import BaseModelDict


class AgentOutput(BaseModelDict):
output: str | None = None
output: Any | None = None
intermediate_steps: List = []
return_direct: bool = False

Expand All @@ -14,4 +14,4 @@ class BaseAgent(ABC):

@abstractmethod
def execute(*args, **kwargs) -> AgentOutput:
pass
pass
7 changes: 6 additions & 1 deletion core/cat/agents/memory_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput:
]
)

output_parser = StrOutputParser()
output_parser = stray.mad_hatter.execute_hook(
"agent_output_parser", output_parser, cat=stray
)

chain = (
prompt
| RunnableLambda(lambda x: utils.langchain_log_prompt(x, "MAIN PROMPT"))
| stray._llm
| RunnableLambda(lambda x: utils.langchain_log_output(x, "MAIN PROMPT OUTPUT"))
| StrOutputParser()
| output_parser
)

output = chain.invoke(
Expand Down
4 changes: 2 additions & 2 deletions core/cat/convo/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Literal
from typing import List, Literal, Any
from cat.utils import BaseModelDict
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from enum import Enum
Expand Down Expand Up @@ -60,7 +60,7 @@ class CatMessage(BaseModelDict):
user_id (str): user id
"""

content: str
content: Any
user_id: str
type: str = "chat"
why: MessageWhy | None = None
Expand Down
2 changes: 1 addition & 1 deletion core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def __call__(self, message_dict):

# prepare final cat message
final_output = CatMessage(
user_id=self.user_id, content=str(agent_output.output), why=why
user_id=self.user_id, content=agent_output.output, why=why
)

# run message through plugins
Expand Down
39 changes: 39 additions & 0 deletions core/cat/mad_hatter/core_plugin/hooks/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,42 @@ def agent_allowed_tools(allowed_tools: List[str], cat) -> List[str]:
"""

return allowed_tools


@hook(priority=0)
def agent_output_parser(output_parser, cat):
"""Hook the output parser.

Allows to edit the output parser of the *Agent*.

Parameters
---------
output_parser : LangchainOutputParser
Output parser of the *Agent*.
cat : StrayCat
StrayCat instance.

Returns
-------
output_parser : LangchainOutputParser
Output parser of the *Agent*.

Examples
--------

Example 1: use Langchain's PydanticOutputParser to return a structured response

```python

from pydantic import BaseModel
from langchain_core.output_parsers.pydantic import PydanticOutputParser

class CheshireCatAnswer(BaseModel):
output: str
level_of_madness: int

output_parser = PydanticOutputParser(pydantic_object=CheshireCatAnswer)

```
"""
return output_parser
Loading