
Commit

Merge branch 'main' into chore/update_docs
vitalii-dynamiq authored Feb 3, 2025
2 parents 7c1c048 + b73e74f commit e7e5086
Showing 26 changed files with 1,297 additions and 287 deletions.
2 changes: 1 addition & 1 deletion dynamiq/connections/connections.py
@@ -999,7 +999,7 @@ def cursor_params(self) -> dict:

class AWSRedshift(BaseConnection):
host: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_HOST"))
port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5432))
port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439))
database: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_DATABASE", "db"))
user: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_USER", "awsuser"))
password: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PASSWORD", "password"))
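Note on this change: 5439 is Amazon Redshift's standard listener port, while 5432 is the PostgreSQL default, so the old value only worked when AWS_REDSHIFT_PORT was set explicitly. A minimal sketch of how the env-backed default resolves, using a simplified stand-in for dynamiq's get_env_var helper and a hypothetical RedshiftSettings model:

import os
from functools import partial
from pydantic import BaseModel, Field

def get_env_var(name: str, default=None):
    # Simplified stand-in: read the environment variable, fall back to the given default.
    return os.environ.get(name, default)

class RedshiftSettings(BaseModel):
    port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439))

# With AWS_REDSHIFT_PORT unset, the connection now defaults to Redshift's standard port.
print(RedshiftSettings().port)  # -> 5439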
71 changes: 47 additions & 24 deletions dynamiq/nodes/agents/orchestrators/linear.py
@@ -2,15 +2,16 @@
from functools import cached_property
from typing import Any

from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter

from dynamiq.connections.managers import ConnectionManager
from dynamiq.nodes import NodeGroup
from dynamiq.nodes.agents.base import Agent
from dynamiq.nodes.agents.orchestrators.linear_manager import LinearAgentManager
from dynamiq.nodes.agents.orchestrators.orchestrator import ActionParseError, Orchestrator
from dynamiq.nodes.agents.orchestrators.orchestrator import ActionParseError, Orchestrator, OrchestratorError
from dynamiq.nodes.node import NodeDependency
from dynamiq.runnables import RunnableConfig, RunnableStatus
from dynamiq.types.feedback import PlanApprovalConfig
from dynamiq.utils.logger import logger


@@ -44,7 +45,7 @@ class LinearOrchestrator(Orchestrator):
use_summarizer (Optional[bool]): Indicates if a final summarizer is used.
summarize_all_answers (Optional[bool]): Indicates whether to summarize answers to all tasks
or use only last one. Will only be applied if use_summarizer is set to True.
max_plan_retries (Optional[int]): Maximum number of plan generation retries.
"""

name: str | None = "LinearOrchestrator"
@@ -53,6 +54,8 @@ class LinearOrchestrator(Orchestrator):
agents: list[Agent] = []
use_summarizer: bool = True
summarize_all_answers: bool = False
max_plan_retries: int = 5
plan_approval: PlanApprovalConfig = Field(default_factory=PlanApprovalConfig)

def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -102,31 +105,51 @@ def agents_descriptions(self) -> str:

def get_tasks(self, input_task: str, config: RunnableConfig = None, **kwargs) -> list[Task]:
"""Generate tasks using the manager agent."""
manager_result = self.manager.run(
input_data={
"action": "plan",
"input_task": input_task,
"agents": self.agents_descriptions,
},
config=config,
run_depends=self._run_depends,
**kwargs,
)
self._run_depends = [NodeDependency(node=self.manager).to_dict()]
manager_result_content = ""
feedback = ""

for _ in range(self.max_plan_retries):
manager_result = self.manager.run(
input_data={
"action": "plan",
"input_task": input_task,
"agents": self.agents_descriptions,
"feedback": feedback,
"previous_plan": manager_result_content,
},
config=config,
run_depends=self._run_depends,
**kwargs,
)
self._run_depends = [NodeDependency(node=self.manager).to_dict()]

if manager_result.status != RunnableStatus.SUCCESS:
error_message = f"LLM '{self.manager.name}' failed: {manager_result.output.get('content')}"
raise ValueError(f"Failed to generate tasks: {error_message}")
if manager_result.status != RunnableStatus.SUCCESS:
error_message = f"LLM '{self.manager.name}' failed: {manager_result.output.get('content')}"
raise ValueError(f"Failed to generate tasks: {error_message}")

manager_result_content = manager_result.output.get("content").get("result")
logger.info(
f"Orchestrator {self.name} - {self.id}: Tasks generated by {self.manager.name} - {self.manager.id}:"
f"\n{manager_result_content}"
)
manager_result_content = manager_result.output.get("content").get("result")
logger.info(
f"Orchestrator {self.name} - {self.id}: Tasks generated by {self.manager.name} - {self.manager.id}:"
f"\n{manager_result_content}"
)
try:
tasks = self.parse_tasks_from_output(manager_result_content)
except ActionParseError as e:
feedback = str(e)
continue

if not self.plan_approval.enabled:
return tasks
else:
approval_result = self.send_approval_message(
self.plan_approval, {"tasks": tasks}, config=config, **kwargs
)

tasks = self.parse_tasks_from_output(manager_result_content)
feedback = approval_result.feedback
if approval_result.is_approved:
return approval_result.data.get("tasks")

return tasks
raise OrchestratorError("Maximum number of loops reached for generating plan.")

def parse_tasks_from_output(self, output: str) -> list[Task]:
"""Parse tasks from the manager's output string."""
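The reworked get_tasks above now loops up to max_plan_retries times, feeding parse errors or rejection feedback (plus the previous plan) back into the manager, and raises OrchestratorError once the retries are exhausted. A standalone sketch of that control flow, with ask_manager, parse_plan, and ask_for_approval as hypothetical stand-ins for the manager run, parse_tasks_from_output, and send_approval_message steps:

def generate_plan(ask_manager, parse_plan, ask_for_approval, approval_enabled=True, max_plan_retries=5):
    previous_plan, feedback = "", ""
    for _ in range(max_plan_retries):
        previous_plan = ask_manager(feedback=feedback, previous_plan=previous_plan)
        try:
            tasks = parse_plan(previous_plan)
        except ValueError as e:          # stands in for ActionParseError
            feedback = str(e)            # feed the parse error into the next planning prompt
            continue
        if not approval_enabled:
            return tasks                 # no human in the loop: accept the first parseable plan
        approved, feedback = ask_for_approval(tasks)
        if approved:
            return tasks
    raise RuntimeError("Maximum number of plan generation retries reached.")

One behavioural detail from the real implementation: when approval is enabled, the tasks returned come from approval_result.data, so a reviewer who edits the plan gets their edited version executed rather than the originally parsed one.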
6 changes: 6 additions & 0 deletions dynamiq/nodes/agents/orchestrators/linear_manager.py
@@ -109,6 +109,12 @@
Ensure that your JSON output is valid and can be parsed by Python.
Double-check that all inputs to subtasks are properly passed
and that the final step (creating the JSON output) is performed last.
Here is the previous plan:
{previous_plan}
Feedback from the user about the previous plan:
{feedback}
"""
PROMPT_TEMPLATE_AGENT_MANAGER_LINEAR_ASSIGN = """
You are a helpful agent responsible for recommending the best agent for a specific task.
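The two new placeholders let the manager see its last attempt and the reviewer's objections on a retry. An illustrative fill of just that prompt fragment (the real template carries much more instruction text, and the substitution mechanism shown here is an assumption):

PROMPT_SNIPPET = """Here is the previous plan:
{previous_plan}
Feedback from the user about the previous plan:
{feedback}
"""

print(PROMPT_SNIPPET.format(
    previous_plan='{"tasks": [{"id": 1, "name": "Research topic"}]}',
    feedback="Split task 1 into separate research and summarization steps.",
))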
5 changes: 4 additions & 1 deletion dynamiq/nodes/exceptions.py
@@ -16,9 +16,12 @@ class NodeException(Exception):
failed_depend (NodeDependency): The dependency that caused the exception.
"""

def __init__(self, failed_depend: Optional["NodeDependency"] = None, message=None):
def __init__(
self, failed_depend: Optional["NodeDependency"] = None, message: str = None, recoverable: bool = False
):
super().__init__(message)
self.failed_depend = failed_depend
self.recoverable = recoverable


class NodeFailedException(NodeException):
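The new recoverable flag lets callers distinguish failures that should merely skip a node (such as a human declining an approval request) from hard errors. A minimal sketch of how a caller might branch on it; the surrounding try/except is illustrative rather than copied from the repo:

from dynamiq.nodes.exceptions import NodeException

try:
    raise NodeException(message="Execution was canceled by human feedback", recoverable=True)
except NodeException as e:
    if e.recoverable:
        # Recoverable failures can be reported as a SKIP result instead of aborting the workflow.
        print(f"Skipping node: {e}")
    else:
        raise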
148 changes: 145 additions & 3 deletions dynamiq/nodes/node.py
@@ -8,6 +8,7 @@
from typing import Any, Callable, ClassVar, Union
from uuid import uuid4

from jinja2 import Template
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field, model_validator

from dynamiq.cache.utils import cache_wf_entity
@@ -24,6 +25,13 @@
from dynamiq.nodes.types import NodeGroup
from dynamiq.runnables import Runnable, RunnableConfig, RunnableResult, RunnableStatus
from dynamiq.storages.vector.base import BaseVectorStoreParams
from dynamiq.types.feedback import (
ApprovalConfig,
ApprovalInputData,
ApprovalStreamingInputEventMessage,
ApprovalStreamingOutputEventMessage,
FeedbackMethod,
)
from dynamiq.types.streaming import STREAMING_EVENT, StreamingConfig, StreamingEventMessage
from dynamiq.utils import format_value, generate_uuid, merge
from dynamiq.utils.duration import format_duration
@@ -204,6 +212,8 @@ class Node(BaseModel, Runnable, ABC):
output_transformer: OutputTransformer = Field(default_factory=OutputTransformer)
caching: CachingConfig = Field(default_factory=CachingConfig)
streaming: StreamingConfig = Field(default_factory=StreamingConfig)
approval: ApprovalConfig = Field(default_factory=ApprovalConfig)

depends: list[NodeDependency] = []
metadata: NodeMetadata | None = None

@@ -461,6 +471,134 @@ def to_dict(self, include_secure_params: bool = False, **kwargs) -> dict:
data["input_mapping"] = format_value(self.input_mapping)
return data

def send_streaming_approval_message(
self, template: str, input_data: dict, approval_config: ApprovalConfig, config: RunnableConfig = None, **kwargs
) -> ApprovalInputData:
"""
Sends approval message and waits for response.
Args:
template (str): Template to send.
input_data (dict): Data that will be sent.
approval_config (ApprovalConfig): Configuration for approval.
config (RunnableConfig, optional): Configuration for the runnable.
**kwargs: Additional keyword arguments.
Returns:
ApprovalInputData: Response to approval message.
"""
event = ApprovalStreamingOutputEventMessage(
wf_run_id=config.run_id,
entity_id=self.id,
data={"template": template, "data": input_data, "mutable_data_params": approval_config.mutable_data_params},
event=approval_config.event,
)

logger.info(f"Node {self.name} - {self.id}: sending approval.")

self.run_on_node_execute_stream(callbacks=config.callbacks, event=event, **kwargs)

output: ApprovalInputData = self.get_input_streaming_event(
event=approval_config.event, event_msg_type=ApprovalStreamingInputEventMessage, config=config
).data

return output

def send_console_approval_message(self, template: str) -> ApprovalInputData:
"""
Sends approval message in console and waits for response.
Args:
template (str): Template to send.
Returns:
ApprovalInputData: Response to approval message.
"""
feedback = input(template)
return ApprovalInputData(feedback=feedback)

def send_approval_message(
self, approval_config: ApprovalConfig, input_data: dict, config: RunnableConfig = None, **kwargs
) -> ApprovalInputData:
"""
Sends approval message and determines if it was approved or disapproved (canceled).
Args:
approval_config (ApprovalConfig): Configuration for the approval.
input_data (dict): Data that will be sent.
config (RunnableConfig, optional): Configuration for the runnable.
**kwargs: Additional keyword arguments.
Returns:
ApprovalInputData: Result of approval.
"""

message = Template(approval_config.msg_template).render(self.to_dict(), input_data=input_data)
match approval_config.feedback_method:
case FeedbackMethod.STREAM:
approval_result = self.send_streaming_approval_message(
message, input_data, approval_config, config=config, **kwargs
)
case FeedbackMethod.CONSOLE:
approval_result = self.send_console_approval_message(message)
case _:
raise ValueError(f"Error: Incorrect feedback method is chosen {approval_config.feedback_method}.")

update_params = {
feature_name: approval_result.data[feature_name]
for feature_name in approval_config.mutable_data_params
if feature_name in approval_result.data
}
approval_result.data = {**input_data, **update_params}

if approval_result.is_approved is None:
if approval_result.feedback == approval_config.accept_pattern:
logger.info(
f"Node {self.name} action was approved by human "
f"with provided feedback '{approval_result.feedback}'."
)
approval_result.is_approved = True

else:
approval_result.is_approved = False
logger.info(
f"Node {self.name} action was canceled by human"
f"with provided feedback '{approval_result.feedback}'."
)

return approval_result

def get_approved_data_or_origin(
self, input_data: dict[str, Any], config: RunnableConfig = None, **kwargs
) -> dict[str, Any]:
"""
Approves or disapproves (cancels) Node execution by requesting feedback.
Updates input data according to the feedback or leaves it the same.
Raises NodeException if execution was canceled by feedback.
Args:
input_data(dict[str, Any]): Input data.
config (RunnableConfig, optional): Configuration for the runnable.
**kwargs: Additional keyword arguments.
Returns:
dict[str, Any]: Updated input data.
Raises:
NodeException: If Node execution was canceled by feedback.
"""
if self.approval.enabled:
approval_result = self.send_approval_message(self.approval, input_data, config=config, **kwargs)
if not approval_result.is_approved:
raise NodeException(
message=f"Execution was canceled by human with feedback {approval_result.feedback}",
recoverable=True,
failed_depend=NodeDependency(self, option="Execution was canceled."),
)
return approval_result.data

return input_data

def run(
self,
input_data: Any,
@@ -496,6 +634,7 @@ def run(
try:
try:
self.validate_depends(depends_result)
input_data = self.get_approved_data_or_origin(input_data, config=config, **merged_kwargs)
except NodeException as e:
transformed_input = input_data | {
k: result.to_tracing_depend_dict() for k, result in depends_result.items()
@@ -508,12 +647,14 @@ def run(
**merged_kwargs,
)
logger.info(f"Node {self.name} - {self.id}: execution skipped.")
return RunnableResult(status=RunnableStatus.SKIP, input=transformed_input, output=format_value(e))
return RunnableResult(
status=RunnableStatus.SKIP,
input=transformed_input,
output=format_value(e, recoverable=e.recoverable),
)

transformed_input = self.transform_input(input_data=input_data, depends_result=depends_result)

self.run_on_node_start(config.callbacks, transformed_input, **merged_kwargs)

cache = cache_wf_entity(
entity_id=self.id,
cache_enabled=self.caching.enabled,
@@ -526,6 +667,7 @@

merged_kwargs["is_output_from_cache"] = from_cache
transformed_output = self.transform_output(output)

self.run_on_node_end(config.callbacks, transformed_output, **merged_kwargs)

logger.info(
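In the run path above, get_approved_data_or_origin is called right after dependency validation, so a disapproval surfaces as a recoverable NodeException and the node is reported as skipped. A self-contained sketch of the console approval path; the template string, accept pattern, and input values are examples rather than library defaults:

from jinja2 import Template

# send_approval_message renders a Jinja2 template with the node's dict plus the input data,
# then compares the human's reply against the configured accept pattern.
msg_template = "Approve running {{ name }} with input {{ input_data }}? Type 'yes' to accept: "
message = Template(msg_template).render({"name": "SQLExecutor"}, input_data={"query": "SELECT 1"})

feedback = "yes"                 # in the console path this comes from input(message)
is_approved = feedback == "yes"  # stands in for the accept_pattern comparison
print(message)
print("approved" if is_approved else "canceled")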
2 changes: 1 addition & 1 deletion dynamiq/nodes/tools/__init__.py
@@ -7,6 +7,6 @@
from .python import Python
from .retriever import RetrievalTool
from .scale_serp import ScaleSerpTool
from .sql_executor import SqlExecutor
from .sql_executor import SQLExecutor
from .tavily import TavilyTool
from .zenrows import ZenRowsTool
20 changes: 15 additions & 5 deletions dynamiq/nodes/tools/function_tool.py
@@ -48,7 +48,7 @@ def execute(self, input_data: dict[str, Any], config: RunnableConfig = None, **k
config = ensure_config(config)
self.run_on_node_execute_run(config.callbacks, **kwargs)

result = self.run_func(input_data)
result = self.run_func(input_data, config=config, **kwargs)

logger.info(f"Tool {self.name} - {self.id}: finished with RESULT:\n{str(result)[:200]}...")
return {"content": result}
@@ -100,10 +100,20 @@ def function_tool(func: Callable[..., T]) -> type[FunctionTool[T]]:

def create_input_schema(func) -> type[BaseModel]:
signature = inspect.signature(func)
params_dict = {param.name: (param.annotation, param.default) for param in signature.parameters.values()}

params_dict = {}

for param in signature.parameters.values():
if param.name == "kwargs" or param.name == "config":
continue
if param.default is inspect.Parameter.empty:
params_dict[param.name] = (param.annotation, ...)
else:
params_dict[param.name] = (param.annotation, param.default)

return create_model(
"FunctionToolInputSchema",
**{k: (v[0], ...) if v[1] is inspect.Parameter.empty else (v[0], v[1]) for k, v in params_dict.items()},
**params_dict,
model_config=dict(extra="allow"),
)

@@ -116,8 +126,8 @@ class FunctionToolFromDecorator(FunctionTool[T]):
_original_func = staticmethod(func)
input_schema: ClassVar[type[BaseModel]] = create_input_schema(func)

def run_func(self, input_data: BaseModel, **_) -> T:
return func(**input_data.model_dump())
def run_func(self, input_data: BaseModel, **kwargs) -> T:
return func(**input_data.model_dump(), **kwargs)

FunctionToolFromDecorator.__name__ = func.__name__
FunctionToolFromDecorator.__qualname__ = (
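The schema builder now skips parameters named config and kwargs and preserves declared defaults instead of forcing every field to be required, while run_func forwards the runtime config and kwargs to the wrapped function. A standalone sketch of the filtering logic; the helper below mirrors the updated create_input_schema but is simplified and not imported from the repo:

import inspect
from pydantic import BaseModel, create_model

def build_input_schema(func) -> type[BaseModel]:
    # Skip runtime-only parameters and keep declared defaults as optional fields.
    fields = {}
    for param in inspect.signature(func).parameters.values():
        if param.name in ("kwargs", "config"):
            continue
        default = ... if param.default is inspect.Parameter.empty else param.default
        fields[param.name] = (param.annotation, default)
    return create_model("FunctionToolInputSchema", **fields)

def add_prefix(text: str, prefix: str = ">> ", config=None, **kwargs) -> str:
    return prefix + text

schema = build_input_schema(add_prefix)
print(list(schema.model_fields))  # ['text', 'prefix'] — config and **kwargs stay out of the schema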
