Skip to content

Commit

Permalink
Refactor ActionsSubtask for more precise initialization logic (#1626)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Jan 29, 2025
1 parent b031531 commit c8a17b8
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Error when serializing `RagContext`.
- `Answer:` being trimmed from LLM's final answer even when using native tool calling.


## [1.2.0] - 2025-01-21
Expand Down
13 changes: 9 additions & 4 deletions griptape/common/prompt_stack/messages/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, TypeVar
from typing import Any, Optional, TypeVar

from attrs import define, field

Expand Down Expand Up @@ -45,8 +45,13 @@ def to_text(self) -> str:
[content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)],
)

def to_artifact(self) -> BaseArtifact:
def to_artifact(self, meta: Optional[dict] = None) -> BaseArtifact:
if meta is None:
meta = {}
if len(self.content) == 1:
return self.content[0].artifact
artifact = self.content[0].artifact
else:
return ListArtifact([content.artifact for content in self.content])
artifact = ListArtifact([content.artifact for content in self.content])

artifact.meta.update(meta)
return artifact
29 changes: 19 additions & 10 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def attach_to(self, parent_task: BaseTask) -> None:
self._origin_task = parent_task
self.structure = parent_task.structure

task_input = self.input
try:
if isinstance(self.input, TextArtifact):
self.__init_from_prompt(self.input.to_text())
if isinstance(task_input, TextArtifact) and task_input.meta.get("is_react_prompt", False):
self.__init_from_prompt(task_input.to_text())
else:
self.__init_from_artifacts(self.input)
self.__init_from_artifact(task_input)

# If StructuredOutputTool was used, treat the input to it as the output of the subtask.
structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)]
if structured_outputs:
output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs]
Expand Down Expand Up @@ -242,31 +244,38 @@ def __init_from_prompt(self, value: str) -> None:
# The LLM failed to follow the ReAct prompt, set the LLM's raw response as the output.
self.output = TextArtifact(value)

def __init_from_artifacts(self, artifacts: ListArtifact) -> None:
"""Parses the input Artifacts to extract the thought and actions.
def __init_from_artifact(self, artifact: TextArtifact | ListArtifact) -> None:
"""Parses the input Artifact to extract either a final answer or thought and actions.
Text Artifacts are used to extract the thought, and ToolAction Artifacts are used to extract the actions.
When the input Artifact is a TextArtifact, it is assumed to be the final answer.
When the input Artifact is a ListArtifact, it is assumed to contain both thought and actions.
Text Artifacts are parsed as the thought, and ToolAction Artifacts parsed as the actions.
Args:
artifacts: The input Artifacts.
            artifact: The input Artifact.
Returns:
None
"""
# When using native tools, we can assume that a TextArtifact is the LLM providing its final answer.
if isinstance(artifact, TextArtifact):
self.output = artifact
return

self.actions = [
self.__process_action_object(artifact.value.to_dict())
for artifact in artifacts.value
for artifact in artifact.value
if isinstance(artifact, ActionArtifact)
]

# When parsing from Artifacts we can't determine the thought unless there are also Actions
if self.actions:
thoughts = [artifact.value for artifact in artifacts.value if isinstance(artifact, TextArtifact)]
thoughts = [artifact.value for artifact in artifact.value if isinstance(artifact, TextArtifact)]
if thoughts:
self.thought = thoughts[0]
else:
if self.output is None:
self.output = TextArtifact(artifacts.to_text())
self.output = TextArtifact(artifact.to_text())

def __parse_actions(self, actions_matches: list[str]) -> list[ToolAction]:
if len(actions_matches) == 0:
Expand Down
14 changes: 8 additions & 6 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,11 @@ def try_run(self) -> ListArtifact | TextArtifact | JsonArtifact | ErrorArtifact:
if self.response_stop_sequence not in self.prompt_driver.tokenizer.stop_sequences:
self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence])

result = self.prompt_driver.run(self.prompt_stack)
output = self.prompt_driver.run(self.prompt_stack).to_artifact(
meta={"is_react_prompt": not self.prompt_driver.use_native_tools}
)
if self.tools:
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))
subtask = self.add_subtask(ActionsSubtask(output))

while True:
if subtask.output is None:
Expand All @@ -195,14 +197,14 @@ def try_run(self) -> ListArtifact | TextArtifact | JsonArtifact | ErrorArtifact:
else:
subtask.run()

result = self.prompt_driver.run(self.prompt_stack)
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))
output = self.prompt_driver.run(self.prompt_stack).to_artifact(
meta={"is_react_prompt": not self.prompt_driver.use_native_tools}
)
subtask = self.add_subtask(ActionsSubtask(output))
else:
break

output = subtask.output
else:
output = result.to_artifact()

if not isinstance(output, (TextArtifact, JsonArtifact, ErrorArtifact)):
raise ValueError(f"Output must be a TextArtifact, JsonArtifact, or ErrorArtifact, not {type(output)}")
Expand Down
7 changes: 5 additions & 2 deletions griptape/tasks/tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def try_run(self) -> ListArtifact | TextArtifact | ErrorArtifact:
result = self.prompt_driver.run(self.prompt_stack)

if self.prompt_driver.use_native_tools:
subtask_input = result.to_artifact()
subtask_input = result.to_artifact(meta={"is_react_prompt": False})
else:
action_matches = re.findall(self.ACTION_PATTERN, result.to_text(), re.DOTALL)

Expand All @@ -77,7 +77,10 @@ def try_run(self) -> ListArtifact | TextArtifact | ErrorArtifact:
action_dict = json.loads(data)

action_dict["tag"] = self.tool.name
subtask_input = J2("tasks/tool_task/subtask.j2").render(action_json=json.dumps(action_dict))
subtask_input = TextArtifact(
J2("tasks/tool_task/subtask.j2").render(action_json=json.dumps(action_dict)),
meta={"is_react_prompt": True},
)

try:
subtask = self.add_subtask(ActionsSubtask(subtask_input))
Expand Down
4 changes: 2 additions & 2 deletions tests/mocks/mock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
]
if any(action_messages):
return Message(
content=[TextMessageContent(TextArtifact(f"Answer: {output}"))],
content=[TextMessageContent(TextArtifact(output))],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)
Expand Down Expand Up @@ -90,7 +90,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
message for message in prompt_stack.messages if message.has_any_content_type(ActionCallMessageContent)
]
if any(action_messages):
yield DeltaMessage(content=TextDeltaMessageContent(f"Answer: {output}"))
yield DeltaMessage(content=TextDeltaMessageContent(output))
yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100))
else:
if self.structured_output_strategy == "tool":
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/events/test_finish_actions_subtask_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from griptape.artifacts import TextArtifact
from griptape.events import FinishActionsSubtaskEvent
from griptape.structures import Agent
from griptape.tasks import ActionsSubtask, PromptTask
Expand All @@ -9,11 +10,12 @@
class TestFinishActionsSubtaskEvent:
@pytest.fixture()
def finish_subtask_event(self):
valid_input = (
valid_input = TextArtifact(
"Thought: need to test\n"
'Actions: [{"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "test input"}}}]\n'
"<|Response|>: test observation\n"
"Answer: test output"
"Answer: test output",
meta={"is_react_prompt": True},
)
task = PromptTask(tools=[MockTool()])
agent = Agent()
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/events/test_start_actions_subtask_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from griptape.artifacts import TextArtifact
from griptape.events import StartActionsSubtaskEvent
from griptape.structures import Agent
from griptape.tasks import ActionsSubtask, PromptTask
Expand All @@ -9,11 +10,12 @@
class TestStartActionsSubtaskEvent:
@pytest.fixture()
def start_subtask_event(self):
valid_input = (
valid_input = TextArtifact(
"Thought: need to test\n"
'Actions: [{"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "test input"}}}]\n'
"<|Response|>: test observation\n"
"Answer: test output"
"Answer: test output",
meta={"is_react_prompt": True},
)
task = PromptTask(tools=[MockTool()])
agent = Agent()
Expand Down
Loading

0 comments on commit c8a17b8

Please sign in to comment.