Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve XML agent #14681

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,18 +1034,17 @@ def _take_next_step(
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
return self._consume_next_step(
[
a
for a in self._iter_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager,
)
]
"""Take a single step."""
next_steps = list(
self._iter_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager,
)
)
return self._consume_next_step(next_steps)

def _iter_next_step(
self,
Expand Down
39 changes: 34 additions & 5 deletions libs/langchain/langchain/agents/output_parsers/xml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

from langchain.agents import AgentOutputParser

Expand Down Expand Up @@ -32,20 +33,48 @@
if "</tool>" in text:
tool, tool_input = text.split("</tool>")
_tool = tool.split("<tool>")[1]
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
if "<tool_input>" in tool_input:
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
else:
raise OutputParserException(
error=ValueError("Invalid format for output."),
llm_output=text,
observation=(
"ERROR: For a fool invocation, be sure to include a <tool_input> and"

Check failure on line 45 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.8)

Ruff (E501)

langchain/agents/output_parsers/xml.py:45:89: E501 Line too long (93 > 88)

Check failure on line 45 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.11)

Ruff (E501)

langchain/agents/output_parsers/xml.py:45:89: E501 Line too long (93 > 88)
"</tool_input> tags. A function without parameters could be invoked with a "

Check failure on line 46 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.8)

Ruff (E501)

langchain/agents/output_parsers/xml.py:46:89: E501 Line too long (100 > 88)

Check failure on line 46 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.11)

Ruff (E501)

langchain/agents/output_parsers/xml.py:46:89: E501 Line too long (100 > 88)
"an empty dictionary as the tool input.\n"
"To invoke a tool, use the format "
"`<tool>$TOOL_NAME</tool><tool_input>$TOOL_INPUT</tool_input>`.\n "

Check failure on line 49 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.8)

Ruff (E501)

langchain/agents/output_parsers/xml.py:49:89: E501 Line too long (91 > 88)

Check failure on line 49 in libs/langchain/langchain/agents/output_parsers/xml.py

View workflow job for this annotation

GitHub Actions / ci (libs/langchain) / lint / build (3.11)

Ruff (E501)

langchain/agents/output_parsers/xml.py:49:89: E501 Line too long (91 > 88)
),
)
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text:
_, answer = text.split("<final_answer>")
if "</final_answer>" in answer:
answer = answer.split("</final_answer>")[0]
return AgentFinish(return_values={"output": answer}, log=text)
else:
raise ValueError
raise OutputParserException(
error=ValueError("Invalid format for output."),
llm_output=text,
observation=(
"ERROR: Please either invoke a tool or provide a final answer."
"To invoke a tool, use the format "
"`<tool>$TOOL_NAME</tool><tool_input>$TOOL_INPUT</tool_input>`. "
"where $TOOL_NAME is one of the provided tools and $TOOL_INPUT "
"is a dictionary of arguments to pass to the tool, "
"matching the schema.\n"
),
send_to_llm=True,
)

def get_format_instructions(self) -> str:
raise NotImplementedError
"""Get the format instructions for this output parser."""
raise NotImplementedError(
"XMLAgentOutputParser does contain format instructions."
)

@property
def _type(self) -> str:
Expand Down
61 changes: 43 additions & 18 deletions libs/langchain/langchain/agents/xml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
class XMLAgent(BaseSingleActionAgent):
"""Agent that uses XML tags.

This agent only works with LLMs not chat models!

Ability of agent to invoke tools varies a lot depending on how good the underlying
LLM is!

Args:
tools: list of tools the agent can choose from
llm_chain: The LLMChain to call to predict the next action
Expand All @@ -22,13 +27,25 @@ class XMLAgent(BaseSingleActionAgent):

.. code-block:: python

from langchain.agents import XMLAgent
from langchain
from langchain.agents import AgentExecutor, XMLAgent
from langchain.chains import LLMChain

tools = ...
model =
chain = LLMChain(
llm=model,
prompt=XMLAgent.get_default_prompt(),
output_parser=XMLAgent.get_default_output_parser(),
)

agent = XMLAgent(tools=tools, llm_chain=chain)

agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
handle_parsing_errors=True
)

agent_executor.invoke({"input": "what's the weather in New york?"})
"""

tools: List[BaseTool]
Expand All @@ -38,6 +55,7 @@ class XMLAgent(BaseSingleActionAgent):

@property
def input_keys(self) -> List[str]:
"""Get the input keys."""
return ["input"]

@staticmethod
Expand All @@ -48,25 +66,38 @@ def get_default_prompt() -> ChatPromptTemplate:

@staticmethod
def get_default_output_parser() -> XMLAgentOutputParser:
"""Get the default output parser."""
return XMLAgentOutputParser()

def _format_intermediate_steps(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
"""Format the steps."""
log = ""
for action, observation in intermediate_steps:
if action.tool == "_Exception":
# This only works correctly when handle_parsing_errors=True
log += action.log # Will contain the llm output from the exception
log += "\n{observation}\n"
pass
else:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input>\n<observation>{observation}</observation>\n"
)
return log

def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
log = ""
for action, observation in intermediate_steps:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input><observation>{observation}</observation>"
)
tools = ""
for tool in self.tools:
tools += f"{tool.name}: {tool.description}\n"
inputs = {
"intermediate_steps": log,
"intermediate_steps": self._format_intermediate_steps(intermediate_steps),
"tools": tools,
"question": kwargs["input"],
"stop": ["</tool_input>", "</final_answer>"],
Expand All @@ -80,17 +111,11 @@ async def aplan(
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
log = ""
for action, observation in intermediate_steps:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input><observation>{observation}</observation>"
)
tools = ""
for tool in self.tools:
tools += f"{tool.name}: {tool.description}\n"
inputs = {
"intermediate_steps": log,
"intermediate_steps": self._format_intermediate_steps(intermediate_steps),
"tools": tools,
"question": kwargs["input"],
"stop": ["</tool_input>", "</final_answer>"],
Expand Down
Loading