diff --git a/dynamiq/connections/connections.py b/dynamiq/connections/connections.py index a2b75528..15bd80cb 100644 --- a/dynamiq/connections/connections.py +++ b/dynamiq/connections/connections.py @@ -999,7 +999,7 @@ def cursor_params(self) -> dict: class AWSRedshift(BaseConnection): host: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_HOST")) - port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5432)) + port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439)) database: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_DATABASE", "db")) user: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_USER", "awsuser")) password: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PASSWORD", "password")) diff --git a/dynamiq/nodes/agents/orchestrators/linear.py b/dynamiq/nodes/agents/orchestrators/linear.py index e772f61a..23efdb3b 100644 --- a/dynamiq/nodes/agents/orchestrators/linear.py +++ b/dynamiq/nodes/agents/orchestrators/linear.py @@ -2,15 +2,16 @@ from functools import cached_property from typing import Any -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter from dynamiq.connections.managers import ConnectionManager from dynamiq.nodes import NodeGroup from dynamiq.nodes.agents.base import Agent from dynamiq.nodes.agents.orchestrators.linear_manager import LinearAgentManager -from dynamiq.nodes.agents.orchestrators.orchestrator import ActionParseError, Orchestrator +from dynamiq.nodes.agents.orchestrators.orchestrator import ActionParseError, Orchestrator, OrchestratorError from dynamiq.nodes.node import NodeDependency from dynamiq.runnables import RunnableConfig, RunnableStatus +from dynamiq.types.feedback import PlanApprovalConfig from dynamiq.utils.logger import logger @@ -44,7 +45,7 @@ class LinearOrchestrator(Orchestrator): use_summarizer (Optional[bool]): Indicates if a final summarizer is used. summarize_all_answers (Optional[bool]): Indicates whether to summarize answers to all tasks or use only last one. Will only be applied if use_summarizer is set to True. - + max_plan_retries (Optional[int]): Maximum number of plan generation retries. 
""" name: str | None = "LinearOrchestrator" @@ -53,6 +54,8 @@ class LinearOrchestrator(Orchestrator): agents: list[Agent] = [] use_summarizer: bool = True summarize_all_answers: bool = False + max_plan_retries: int = 5 + plan_approval: PlanApprovalConfig = Field(default_factory=PlanApprovalConfig) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -102,31 +105,51 @@ def agents_descriptions(self) -> str: def get_tasks(self, input_task: str, config: RunnableConfig = None, **kwargs) -> list[Task]: """Generate tasks using the manager agent.""" - manager_result = self.manager.run( - input_data={ - "action": "plan", - "input_task": input_task, - "agents": self.agents_descriptions, - }, - config=config, - run_depends=self._run_depends, - **kwargs, - ) - self._run_depends = [NodeDependency(node=self.manager).to_dict()] + manager_result_content = "" + feedback = "" + + for _ in range(self.max_plan_retries): + manager_result = self.manager.run( + input_data={ + "action": "plan", + "input_task": input_task, + "agents": self.agents_descriptions, + "feedback": feedback, + "previous_plan": manager_result_content, + }, + config=config, + run_depends=self._run_depends, + **kwargs, + ) + self._run_depends = [NodeDependency(node=self.manager).to_dict()] - if manager_result.status != RunnableStatus.SUCCESS: - error_message = f"LLM '{self.manager.name}' failed: {manager_result.output.get('content')}" - raise ValueError(f"Failed to generate tasks: {error_message}") + if manager_result.status != RunnableStatus.SUCCESS: + error_message = f"LLM '{self.manager.name}' failed: {manager_result.output.get('content')}" + raise ValueError(f"Failed to generate tasks: {error_message}") - manager_result_content = manager_result.output.get("content").get("result") - logger.info( - f"Orchestrator {self.name} - {self.id}: Tasks generated by {self.manager.name} - {self.manager.id}:" - f"\n{manager_result_content}" - ) + manager_result_content = manager_result.output.get("content").get("result") + logger.info( + f"Orchestrator {self.name} - {self.id}: Tasks generated by {self.manager.name} - {self.manager.id}:" + f"\n{manager_result_content}" + ) + try: + tasks = self.parse_tasks_from_output(manager_result_content) + except ActionParseError as e: + feedback = str(e) + continue + + if not self.plan_approval.enabled: + return tasks + else: + approval_result = self.send_approval_message( + self.plan_approval, {"tasks": tasks}, config=config, **kwargs + ) - tasks = self.parse_tasks_from_output(manager_result_content) + feedback = approval_result.feedback + if approval_result.is_approved: + return approval_result.data.get("tasks") - return tasks + raise OrchestratorError("Maximum number of loops reached for generating plan.") def parse_tasks_from_output(self, output: str) -> list[Task]: """Parse tasks from the manager's output string.""" diff --git a/dynamiq/nodes/agents/orchestrators/linear_manager.py b/dynamiq/nodes/agents/orchestrators/linear_manager.py index 1b6c3989..0be69037 100644 --- a/dynamiq/nodes/agents/orchestrators/linear_manager.py +++ b/dynamiq/nodes/agents/orchestrators/linear_manager.py @@ -109,6 +109,12 @@ Ensure that your JSON output is valid and can be parsed by Python. Double-check that all inputs to subtasks are properly passed and that the final step (creating the JSON output) is performed last. 
+ +Here is the previous plan: +{previous_plan} + +Feedback from the user about the previous plan: +{feedback} """ PROMPT_TEMPLATE_AGENT_MANAGER_LINEAR_ASSIGN = """ You are a helpful agent responsible for recommending the best agent for a specific task. diff --git a/dynamiq/nodes/exceptions.py b/dynamiq/nodes/exceptions.py index 5ead1c4f..25ff55e7 100644 --- a/dynamiq/nodes/exceptions.py +++ b/dynamiq/nodes/exceptions.py @@ -16,9 +16,12 @@ class NodeException(Exception): failed_depend (NodeDependency): The dependency that caused the exception. """ - def __init__(self, failed_depend: Optional["NodeDependency"] = None, message=None): + def __init__( + self, failed_depend: Optional["NodeDependency"] = None, message: str | None = None, recoverable: bool = False + ): super().__init__(message) self.failed_depend = failed_depend + self.recoverable = recoverable class NodeFailedException(NodeException): diff --git a/dynamiq/nodes/node.py b/dynamiq/nodes/node.py index 0c871801..7eeb8dee 100644 --- a/dynamiq/nodes/node.py +++ b/dynamiq/nodes/node.py @@ -8,6 +8,7 @@ from typing import Any, Callable, ClassVar, Union from uuid import uuid4 +from jinja2 import Template from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field, model_validator from dynamiq.cache.utils import cache_wf_entity @@ -24,6 +25,13 @@ from dynamiq.nodes.types import NodeGroup from dynamiq.runnables import Runnable, RunnableConfig, RunnableResult, RunnableStatus from dynamiq.storages.vector.base import BaseVectorStoreParams +from dynamiq.types.feedback import ( + ApprovalConfig, + ApprovalInputData, + ApprovalStreamingInputEventMessage, + ApprovalStreamingOutputEventMessage, + FeedbackMethod, +) from dynamiq.types.streaming import STREAMING_EVENT, StreamingConfig, StreamingEventMessage from dynamiq.utils import format_value, generate_uuid, merge from dynamiq.utils.duration import format_duration @@ -204,6 +212,8 @@ class Node(BaseModel, Runnable, ABC): output_transformer: OutputTransformer = Field(default_factory=OutputTransformer) caching: CachingConfig = Field(default_factory=CachingConfig) streaming: StreamingConfig = Field(default_factory=StreamingConfig) + approval: ApprovalConfig = Field(default_factory=ApprovalConfig) + depends: list[NodeDependency] = [] metadata: NodeMetadata | None = None @@ -461,6 +471,134 @@ def to_dict(self, include_secure_params: bool = False, **kwargs) -> dict: data["input_mapping"] = format_value(self.input_mapping) return data + def send_streaming_approval_message( + self, template: str, input_data: dict, approval_config: ApprovalConfig, config: RunnableConfig = None, **kwargs + ) -> ApprovalInputData: + """ + Sends an approval message and waits for a response. + + Args: + template (str): Template to send. + input_data (dict): Data that will be sent. + approval_config (ApprovalConfig): Configuration for approval. + config (RunnableConfig, optional): Configuration for the runnable. + **kwargs: Additional keyword arguments. + + Returns: + ApprovalInputData: Response to approval message.
+ + """ + event = ApprovalStreamingOutputEventMessage( + wf_run_id=config.run_id, + entity_id=self.id, + data={"template": template, "data": input_data, "mutable_data_params": approval_config.mutable_data_params}, + event=approval_config.event, + ) + + logger.info(f"Node {self.name} - {self.id}: sending approval.") + + self.run_on_node_execute_stream(callbacks=config.callbacks, event=event, **kwargs) + + output: ApprovalInputData = self.get_input_streaming_event( + event=approval_config.event, event_msg_type=ApprovalStreamingInputEventMessage, config=config + ).data + + return output + + def send_console_approval_message(self, template: str) -> ApprovalInputData: + """ + Sends approval message in console and waits for response. + + Args: + template (dict): Template to send. + Returns: + ApprovalInputData: Response to approval message. + """ + feedback = input(template) + return ApprovalInputData(feedback=feedback) + + def send_approval_message( + self, approval_config: ApprovalConfig, input_data: dict, config: RunnableConfig = None, **kwargs + ) -> ApprovalInputData: + """ + Sends approval message and determines if it was approved or disapproved (canceled). + + Args: + approval_config (ApprovalConfig): Configuration for the approval. + input_data (dict): Data that will be sent. + config (RunnableConfig, optional): Configuration for the runnable. + **kwargs: Additional keyword arguments. + + Returns: + ApprovalInputData: Result of approval. + """ + + message = Template(approval_config.msg_template).render(self.to_dict(), input_data=input_data) + match approval_config.feedback_method: + case FeedbackMethod.STREAM: + approval_result = self.send_streaming_approval_message( + message, input_data, approval_config, config=config, **kwargs + ) + case FeedbackMethod.CONSOLE: + approval_result = self.send_console_approval_message(message) + case _: + raise ValueError(f"Error: Incorrect feedback method is chosen {approval_config.feedback_method}.") + + update_params = { + feature_name: approval_result.data[feature_name] + for feature_name in approval_config.mutable_data_params + if feature_name in approval_result.data + } + approval_result.data = {**input_data, **update_params} + + if approval_result.is_approved is None: + if approval_result.feedback == approval_config.accept_pattern: + logger.info( + f"Node {self.name} action was approved by human " + f"with provided feedback '{approval_result.feedback}'." + ) + approval_result.is_approved = True + + else: + approval_result.is_approved = False + logger.info( + f"Node {self.name} action was canceled by human" + f"with provided feedback '{approval_result.feedback}'." + ) + + return approval_result + + def get_approved_data_or_origin( + self, input_data: dict[str, Any], config: RunnableConfig = None, **kwargs + ) -> dict[str, Any]: + """ + Approves or disapproves (cancels) Node execution by requesting feedback. + Updates input data according to the feedback or leaves it the same. + Raises NodeException if execution was canceled by feedback. + + Args: + input_data(dict[str, Any]): Input data. + config (RunnableConfig, optional): Configuration for the runnable. + **kwargs: Additional keyword arguments. + + Returns: + dict[str, Any]: Updated input data. + + Raises: + NodeException: If Node execution was canceled by feedback. 
+ """ + if self.approval.enabled: + approval_result = self.send_approval_message(self.approval, input_data, config=config, **kwargs) + if not approval_result.is_approved: + raise NodeException( + message=f"Execution was canceled by human with feedback {approval_result.feedback}", + recoverable=True, + failed_depend=NodeDependency(self, option="Execution was canceled."), + ) + return approval_result.data + + return input_data + def run( self, input_data: Any, @@ -496,6 +634,7 @@ def run( try: try: self.validate_depends(depends_result) + input_data = self.get_approved_data_or_origin(input_data, config=config, **merged_kwargs) except NodeException as e: transformed_input = input_data | { k: result.to_tracing_depend_dict() for k, result in depends_result.items() @@ -508,12 +647,14 @@ def run( **merged_kwargs, ) logger.info(f"Node {self.name} - {self.id}: execution skipped.") - return RunnableResult(status=RunnableStatus.SKIP, input=transformed_input, output=format_value(e)) + return RunnableResult( + status=RunnableStatus.SKIP, + input=transformed_input, + output=format_value(e, recoverable=e.recoverable), + ) transformed_input = self.transform_input(input_data=input_data, depends_result=depends_result) - self.run_on_node_start(config.callbacks, transformed_input, **merged_kwargs) - cache = cache_wf_entity( entity_id=self.id, cache_enabled=self.caching.enabled, @@ -526,6 +667,7 @@ def run( merged_kwargs["is_output_from_cache"] = from_cache transformed_output = self.transform_output(output) + self.run_on_node_end(config.callbacks, transformed_output, **merged_kwargs) logger.info( diff --git a/dynamiq/nodes/tools/__init__.py b/dynamiq/nodes/tools/__init__.py index 3481b03f..aabe8809 100644 --- a/dynamiq/nodes/tools/__init__.py +++ b/dynamiq/nodes/tools/__init__.py @@ -7,6 +7,6 @@ from .python import Python from .retriever import RetrievalTool from .scale_serp import ScaleSerpTool -from .sql_executor import SqlExecutor +from .sql_executor import SQLExecutor from .tavily import TavilyTool from .zenrows import ZenRowsTool diff --git a/dynamiq/nodes/tools/function_tool.py b/dynamiq/nodes/tools/function_tool.py index 2315a6bc..622d20dc 100644 --- a/dynamiq/nodes/tools/function_tool.py +++ b/dynamiq/nodes/tools/function_tool.py @@ -48,7 +48,7 @@ def execute(self, input_data: dict[str, Any], config: RunnableConfig = None, **k config = ensure_config(config) self.run_on_node_execute_run(config.callbacks, **kwargs) - result = self.run_func(input_data) + result = self.run_func(input_data, config=config, **kwargs) logger.info(f"Tool {self.name} - {self.id}: finished with RESULT:\n{str(result)[:200]}...") return {"content": result} @@ -100,10 +100,20 @@ def function_tool(func: Callable[..., T]) -> type[FunctionTool[T]]: def create_input_schema(func) -> type[BaseModel]: signature = inspect.signature(func) - params_dict = {param.name: (param.annotation, param.default) for param in signature.parameters.values()} + + params_dict = {} + + for param in signature.parameters.values(): + if param.name == "kwargs" or param.name == "config": + continue + if param.default is inspect.Parameter.empty: + params_dict[param.name] = (param.annotation, ...) + else: + params_dict[param.name] = (param.annotation, param.default) + return create_model( "FunctionToolInputSchema", - **{k: (v[0], ...) 
if v[1] is inspect.Parameter.empty else (v[0], v[1]) for k, v in params_dict.items()}, + **params_dict, model_config=dict(extra="allow"), ) @@ -116,8 +126,8 @@ class FunctionToolFromDecorator(FunctionTool[T]): _original_func = staticmethod(func) input_schema: ClassVar[type[BaseModel]] = create_input_schema(func) - def run_func(self, input_data: BaseModel, **_) -> T: - return func(**input_data.model_dump()) + def run_func(self, input_data: BaseModel, **kwargs) -> T: + return func(**input_data.model_dump(), **kwargs) FunctionToolFromDecorator.__name__ = func.__name__ FunctionToolFromDecorator.__qualname__ = ( diff --git a/dynamiq/nodes/tools/human_feedback.py b/dynamiq/nodes/tools/human_feedback.py index 24df9e4e..fa3bbc53 100644 --- a/dynamiq/nodes/tools/human_feedback.py +++ b/dynamiq/nodes/tools/human_feedback.py @@ -1,4 +1,3 @@ -import enum from abc import ABC, abstractmethod from typing import Any, ClassVar, Literal @@ -7,15 +6,11 @@ from dynamiq.nodes import NodeGroup from dynamiq.nodes.node import Node, ensure_config from dynamiq.runnables import RunnableConfig +from dynamiq.types.feedback import FeedbackMethod from dynamiq.types.streaming import StreamingEventMessage from dynamiq.utils.logger import logger -class InputMethod(str, enum.Enum): - console = "console" - stream = "stream" - - class HFStreamingInputEventMessageData(BaseModel): content: str @@ -54,6 +49,26 @@ def get_input(self, prompt: str, **kwargs) -> str: pass +class OutputMethodCallable(ABC): + """ + Abstract base class for sending messages. + + This class defines the interface for various output methods that can be used + to send information to the user in the MessageSenderTool. + """ + + @abstractmethod + def send_message(self, message: str, **kwargs) -> None: + """ + Sends a message to the user. + + Args: + message (str): The message to send to the user. + """ + + pass + + class HumanFeedbackInputSchema(BaseModel): input: str = Field(..., description="Parameter to provide a question to the user.") @@ -69,13 +84,15 @@ class HumanFeedbackTool(Node): group (Literal[NodeGroup.TOOLS]): The group the node belongs to. name (str): The name of the tool. description (str): A brief description of the tool's purpose. - input_method (InputMethod): The method used to gather user input. + input_method (FeedbackMethod | InputMethodCallable): The method used to gather user input. """ group: Literal[NodeGroup.TOOLS] = NodeGroup.TOOLS name: str = "Human Feedback Tool" - description: str = "Tool to gather user information. Use it to check actual information or get additional input." - input_method: InputMethod | InputMethodCallable = InputMethod.console + description: str = ( + "Tool to gather feedback from the user. Use it to check actual information or get additional input."
+ ) + input_method: FeedbackMethod | InputMethodCallable = FeedbackMethod.CONSOLE input_schema: ClassVar[type[HumanFeedbackInputSchema]] = HumanFeedbackInputSchema model_config = ConfigDict(arbitrary_types_allowed=True) @@ -145,14 +162,14 @@ def execute( self.run_on_node_execute_run(config.callbacks, **kwargs) input_text = input_data.input - if isinstance(self.input_method, InputMethod): - if self.input_method == InputMethod.console: + if isinstance(self.input_method, FeedbackMethod): + if self.input_method == FeedbackMethod.CONSOLE: result = self.input_method_console(input_text) - elif self.input_method == InputMethod.stream: + elif self.input_method == FeedbackMethod.STREAM: streaming = getattr(config.nodes_override.get(self.id), "streaming", None) or self.streaming if not streaming.input_streaming_enabled: raise ValueError( - f"'{InputMethod.stream.value}' input method requires enabled input and output streaming." + f"'{FeedbackMethod.STREAM.value}' input method requires enabled input and output streaming." ) result = self.input_method_streaming(prompt=input_text, config=config, **kwargs) @@ -163,3 +180,97 @@ def execute( logger.debug(f"Tool {self.name} - {self.id}: finished with result {result}") return {"content": result} + + +class MessageSenderInputSchema(BaseModel): + input: str = Field(..., description="Parameter to provide a message to the user.") + + +class MessageSenderTool(Node): + """ + A tool for sending messages. + + Attributes: + group (Literal[NodeGroup.TOOLS]): The group the node belongs to. + name (str): The name of the tool. + description (str): A brief description of the tool's purpose. + output_method (FeedbackMethod | OutputMethodCallable): The method used to send messages to the user. + """ + + group: Literal[NodeGroup.TOOLS] = NodeGroup.TOOLS + name: str = "Message Sender Tool" + description: str = "Tool that can be used to send messages to the user." + output_method: FeedbackMethod | OutputMethodCallable = FeedbackMethod.CONSOLE + input_schema: ClassVar[type[MessageSenderInputSchema]] = MessageSenderInputSchema + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def output_method_console(self, prompt: str) -> None: + """ + Sends a message to the console. + + Args: + prompt (str): The prompt to display to the user. + """ + print(prompt) + + def output_method_streaming(self, prompt: str, config: RunnableConfig, **kwargs) -> None: + """ + Sends a message using the streaming method. + + Args: + prompt (str): The prompt to display to the user. + config (RunnableConfig, optional): The configuration for the runnable. Defaults to None. + """ + + event = HFStreamingOutputEventMessage( + wf_run_id=config.run_id, + entity_id=self.id, + data=HFStreamingOutputEventMessageData(prompt=prompt), + event=self.streaming.event, + ) + logger.debug(f"Tool {self.name} - {self.id}: sending output event {event}") + self.run_on_node_execute_stream(callbacks=config.callbacks, event=event, **kwargs) + + def execute( + self, input_data: MessageSenderInputSchema, config: RunnableConfig | None = None, **kwargs + ) -> dict[str, Any]: + """ + Execute the tool with the provided input data and configuration. + + This method sends a message to the user using the specified output method. + + Args: + input_data (MessageSenderInputSchema): The input data containing the message for the user. + config (RunnableConfig, optional): The configuration for the runnable. Defaults to None. + **kwargs: Additional keyword arguments to be passed to the node execute run.
+ + Returns: + dict[str, Any]: A dictionary containing a confirmation message under the 'content' key. + + Raises: + ValueError: If an unsupported feedback method is configured. + """ + logger.debug(f"Tool {self.name} - {self.id}: started with input data {input_data.model_dump()}") + config = ensure_config(config) + self.run_on_node_execute_run(config.callbacks, **kwargs) + + input_text = input_data.input + if isinstance(self.output_method, FeedbackMethod): + if self.output_method == FeedbackMethod.CONSOLE: + self.output_method_console(input_text) + elif self.output_method == FeedbackMethod.STREAM: + streaming = getattr(config.nodes_override.get(self.id), "streaming", None) or self.streaming + if not streaming.input_streaming_enabled: + raise ValueError( + f"'{FeedbackMethod.STREAM.value}' output method requires enabled input and output streaming." + ) + + self.output_method_streaming(prompt=input_text, config=config, **kwargs) + else: + raise ValueError(f"Unsupported feedback method: {self.output_method}") + else: + self.output_method.send_message(input_text) + + logger.debug(f"Tool {self.name} - {self.id}: finished") + return {"content": f"Message {input_text} was sent."} diff --git a/dynamiq/nodes/tools/sql_executor.py b/dynamiq/nodes/tools/sql_executor.py index 9de1336c..85185b15 100644 --- a/dynamiq/nodes/tools/sql_executor.py +++ b/dynamiq/nodes/tools/sql_executor.py @@ -11,10 +11,10 @@ class SQLInputSchema(BaseModel): - query: str = Field(..., description="Parameter to provide a query that needs to be executed.") + query: str | None = Field(None, description="Parameter to provide a query that needs to be executed.") -class SqlExecutor(ConnectionNode): +class SQLExecutor(ConnectionNode): """ A tool for SQL query execution. @@ -33,6 +33,7 @@ class SqlExecutor(ConnectionNode): "You can use this tool to execute the query, specified for PostgreSQL, MySQL, Snowflake, AWS Redshift." ) connection: PostgreSQL | MySQL | Snowflake | AWSRedshift + query: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -44,8 +45,10 @@ def execute(self, input_data, config: RunnableConfig = None, **kwargs) -> dict[s config = ensure_config(config) self.run_on_node_execute_run(config.callbacks, **kwargs) - query = input_data.query + query = input_data.query or self.query try: + if not query: + raise ValueError("Query cannot be empty") cursor = self.client.cursor( **self.connection.cursor_params if not isinstance(self.connection, (PostgreSQL, AWSRedshift)) else {} ) diff --git a/dynamiq/types/feedback.py b/dynamiq/types/feedback.py new file mode 100644 index 00000000..f51b271e --- /dev/null +++ b/dynamiq/types/feedback.py @@ -0,0 +1,76 @@ +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from dynamiq.types.streaming import StreamingEventMessage + + +class FeedbackMethod(Enum): + CONSOLE = "console" + STREAM = "stream" + + +APPROVAL_EVENT = "approval" +PLAN_APPROVAL_EVENT = "plan_approval" + + +class ApprovalOutputEventData(BaseModel): + template: str = Field(..., description="Message template that will be sent.") + data: dict[str, Any] = Field(..., description="Data that will be sent.") + mutable_data_params: list[str] = Field( + ..., + description=( + "List of parameters from the 'data' " + "field that can be updated. " + "This field is used to protect data that shouldn't be modified."
+ ), + ) + + +class ApprovalStreamingOutputEventMessage(StreamingEventMessage): + data: ApprovalOutputEventData + + +class ApprovalInputData(BaseModel): + feedback: str | None = None + data: dict[str, Any] = {} + is_approved: bool | None = None + + @model_validator(mode="after") + def validate_feedback(self): + if self.feedback is None and self.is_approved is None: + raise ValueError("Error: No feedback or approval has been provided.") + return self + + +class ApprovalStreamingInputEventMessage(StreamingEventMessage): + data: ApprovalInputData + + +class ApprovalConfig(BaseModel): + enabled: bool = False + feedback_method: FeedbackMethod = FeedbackMethod.CONSOLE + + mutable_data_params: list[str] = [] + msg_template: str = """ + Node {{name}}: Approve or cancel execution. Send nothing for approval; provide feedback to cancel. + """ + + event: str = APPROVAL_EVENT + accept_pattern: str = "" + + +class PlanApprovalConfig(ApprovalConfig): + msg_template: str = ( + """ + Approve or cancel plan. Send nothing for approval; provide feedback to cancel. + {% for task in input_data.tasks %} + Task name: {{ task.name }} + Task description: {{ task.description }} + Task dependencies: {{ task.dependencies }} + {% endfor %} + """ + ) + + event: str = PLAN_APPROVAL_EVENT diff --git a/dynamiq/types/streaming.py b/dynamiq/types/streaming.py index 9ce8d2eb..989e2bc8 100644 --- a/dynamiq/types/streaming.py +++ b/dynamiq/types/streaming.py @@ -12,8 +12,8 @@ class StreamingMode(str, Enum): """Enumeration for streaming modes.""" - FINAL = "final" # Only final output - ALL = "all" # All intermediate steps and final output + FINAL = "final" # Streams only the final output of agent nodes. + ALL = "all" # Streams all intermediate steps and the final output of agent and LLM nodes. STREAMING_EVENT = "streaming" @@ -29,6 +29,7 @@ class StreamingEventMessage(BaseModel): data (Any): Data associated with the event. event (str | None): Event name. Defaults to "streaming". """ + run_id: str | None = None wf_run_id: str | None = Field(default_factory=generate_uuid) entity_id: str diff --git a/dynamiq/utils/feedback.py b/dynamiq/utils/feedback.py new file mode 100644 index 00000000..754613b6 --- /dev/null +++ b/dynamiq/utils/feedback.py @@ -0,0 +1,24 @@ +from dynamiq.callbacks.streaming import StreamingEventMessage +from dynamiq.runnables import RunnableConfig +from dynamiq.types.feedback import FeedbackMethod + + +def send_message( + event_message: StreamingEventMessage, + config: RunnableConfig, + feedback_method: FeedbackMethod = FeedbackMethod.STREAM, +) -> None: + """Emits a message. + + Args: + event_message (StreamingEventMessage): Message to send. + config (RunnableConfig): Configuration for the runnable. + feedback_method (FeedbackMethod, optional): Determines where the message is sent. Defaults to "stream".
+ """ + + match feedback_method: + case FeedbackMethod.CONSOLE: + print(event_message.data) + case FeedbackMethod.STREAM: + for callback in config.callbacks: + callback.on_node_execute_stream({}, event=event_message) diff --git a/examples/graph_like/code_assistant.py b/examples/graph_like/code_assistant.py index d11a80f6..3929bffd 100644 --- a/examples/graph_like/code_assistant.py +++ b/examples/graph_like/code_assistant.py @@ -11,6 +11,7 @@ from dynamiq.nodes.types import InferenceMode from dynamiq.prompts import Message, Prompt from dynamiq.runnables import RunnableConfig, RunnableResult, RunnableStatus +from dynamiq.utils.logger import logger from examples.llm_setup import setup_llm @@ -61,12 +62,12 @@ def code_llm(messages, structured_output=True): return json.loads(llm_result) if structured_output else llm_result - def generate_code_solution(context: dict[str, Any]): + def generate_code_solution(context: dict[str, Any], **kwargs): """ Generate a code solution """ - print("#####CODE GENERATION#####") + logger.info("CODE GENERATION") messages = context.get("messages") @@ -87,8 +88,8 @@ def generate_code_solution(context: dict[str, Any]): return {"result": code_solution, **context} - def reflect(context: dict[str, Any]): - print("#####REFLECTING ON ERRORS#####") + def reflect(context: dict[str, Any], **kwargs): + logger.info("REFLECTING ON ERRORS") reflections = code_llm(messages=context.get("messages")) context["messages"] += [Message(role="assistant", content=f"Here are reflections on the error: {reflections}")] return {"result": reflections, **context} @@ -104,25 +105,25 @@ def validate_code(context: dict[str, Any], **_): state (dict): New key added to state, error """ - print("#####CHECKING#####") + logger.info("CHECKING") solution = context["solution"] result = tool_code.run(input_data={"python": solution.get("code"), "packages": solution.get("libraries")}) if result.status == RunnableStatus.SUCCESS: - print("#####SUCCESSFUL#####") + logger.info("SUCCESSFUL") successful_message = [ Message(role="user", content=f"Your code executed successfully {result.output['content']}") ] context["messages"] += successful_message context["reiterate"] = False else: - print("#####FAILED#####") + logger.info("FAILED") error_message = [ Message( role="user", content=( - f"Your solution failed the code execution test: {result.output['content']}." - " Reflect on possible errors." + f"Your solution failed to execute: {result.output['content']}." + " Reflect on possible errors and solutions." 
), ) ] @@ -144,7 +145,7 @@ def validate_code(context: dict[str, Any], **_): orchestrator.add_edge("generate_code", "validate_code") orchestrator.add_edge("reflect", "generate_code") - def orchestrate(context: dict[str, Any]) -> str: + def orchestrate(context: dict[str, Any], **kwargs) -> str: return "reflect" if context["reiterate"] else END orchestrator.add_conditional_edge("validate_code", ["generate_code", END], orchestrate) @@ -169,7 +170,7 @@ def run_orchestrator(request="Write 100 lines of code.") -> RunnableResult: config=RunnableConfig(callbacks=[tracing]), ) - print(result.output[orchestrator.id]) + logger.info(result.output[orchestrator.id]) return result.output[orchestrator.id]["output"]["content"], tracing.runs diff --git a/examples/graph_like/concierge_orchestration.py b/examples/graph_like/concierge_orchestration.py index 2eb37576..2a1f8424 100644 --- a/examples/graph_like/concierge_orchestration.py +++ b/examples/graph_like/concierge_orchestration.py @@ -4,7 +4,6 @@ from dynamiq import Workflow from dynamiq.callbacks import TracingCallbackHandler from dynamiq.flows import Flow -from dynamiq.nodes.agents.base import Agent from dynamiq.nodes.agents.orchestrators.graph import END, GraphOrchestrator from dynamiq.nodes.agents.orchestrators.graph_manager import GraphAgentManager from dynamiq.nodes.agents.react import ReActAgent @@ -16,33 +15,6 @@ from examples.llm_setup import setup_llm -class CustomToolAgent(Agent): - def execute(self, input_data, config: RunnableConfig | None = None, **kwargs): - """ - Executes the agent with the given input data. - """ - logger.debug(f"Agent {self.name} - {self.id}: started with input {input_data}") - self.reset_run_state() - self.run_on_node_execute_run(config.callbacks, **kwargs) - - tool_result = self.tools[0].run( - input_data=input_data, - config=config, - **kwargs, - ) - - execution_result = { - "content": tool_result, - "intermediate_steps": self._intermediate_steps, - } - - if self.streaming.enabled: - self.run_on_node_execute_stream(config.callbacks, execution_result, **kwargs) - - logger.debug(f"Agent {self.name} - {self.id}: finished with result {tool_result}") - return execution_result.output.get("context") - - def create_workflow() -> Workflow: """ Create the workflow with all necessary agents and tools. @@ -54,10 +26,18 @@ def create_workflow() -> Workflow: llm = setup_llm() # Stock Lookup Agent - def stock_lookup(context: dict[str, Any]): + def stock_lookup(context: dict[str, Any], **kwargs): + + def search_for_stock_symbol(company_name: str) -> str: + """Useful for searching for a stock symbol given a company name.""" + logger.info("Searching for stock symbol") + return company_name.upper() + @function_tool - def lookup_stock_price(stock_symbol: str) -> str: - """Useful for looking up a stock price.""" + def lookup_stock_price(company_name: str, **kwargs) -> str: + """Useful for looking up a stock price.""" + + stock_symbol = search_for_stock_symbol(company_name) logger.info(f"Looking up stock price for {stock_symbol}") context["current_task"] = "" @@ -65,65 +45,43 @@ def lookup_stock_price(stock_symbol: str) -> str: return f"Symbol {stock_symbol} is currently trading at $100.00" - @function_tool - def search_for_stock_symbol(str: str) -> str: - """Useful for searching for a stock symbol given a free-form company name.""" - logger.info("Searching for stock symbol") - return str.upper() - stock_lookup_agent = ReActAgent( name="stock_lookup_agent", - role="""You are a helpful assistant that is looking up stock prices.
- The user may not know the stock symbol of the company they're interested in, - so you can help them look it up by the name of the company. - You can only look up stock symbols given to you by the search_for_stock_symbol tool, don't make them up. - Trust the output of the search_for_stock_symbol tool even if it doesn't make sense to you.""", + role="""You are a helpful assistant that is searching for stock prices.""", goal="Provide actions according to role", # noqa: E501 llm=llm, - tools=[lookup_stock_price(), search_for_stock_symbol()], + tools=[lookup_stock_price()], ) + + company_name = input("Provide the company name for which you want to find the stock price.\n") result = stock_lookup_agent.run( - input_data={"input": "Get stock price for nvidia."}, + input_data={"input": f"Get stock price for {company_name}."}, ) return {"result": result.output.get("content"), **context} # Auth Agent - def authentificate(context: dict[str, Any]): - @function_tool - def store_username(username: str) -> None: - """Adds the username to the user state.""" - print("Recording username") - return "Username was recorded." + def authenticate(context: dict[str, Any], **kwargs): @function_tool - def login(password: str) -> None: - """Given a password, logs in and stores a session token in the user state.""" + def login(password: str, **kwargs) -> None: + """Given a password, logs in and stores a session in the user state.""" logger.info(f"Logging in with password {password}") context["authenticated"] = True return f"Successfully logged in with password {password}" @function_tool - def is_authenticated() -> bool: - """Checks if the user has a session token.""" - print("Checking if authenticated") - return "User is authenticated" - - @function_tool - def done() -> None: - """When you complete your task, call this tool.""" - logger.info("Authentication is complete") - return "Authentication is complete" + def is_authenticated(**kwargs) -> bool: + """Checks if the user has a saved session in the user state.""" + logger.info("Checking if authenticated") + return context.get("authenticated", False) auth_agent = ReActAgent( name="auth_agent", - role="""You are a helpful assistant that is authenticating a user. - Your task is to get a valid session token stored in the user state. - To do this, the user must supply you with a username and a valid password. You can ask them to supply these.
- If the user supplies a username and password, call the tool "login" to log them in.""", + role="""You are a helpful assistant that is authenticating a user.""", goal="Provide actions according to role", # noqa: E501 llm=llm, - tools=[store_username(), login(), is_authenticated(), done()], + tools=[login(), is_authenticated()], ) result = auth_agent.run( @@ -132,49 +90,31 @@ def done() -> None: return {"result": result.output.get("content"), **context} - def account_balance(context: dict[str, Any]): + def account_balance(context: dict[str, Any], **kwargs): @function_tool - def get_account_id(account_name: str) -> str: - """Useful for looking up an account ID.""" - print(f"Looking up account ID for {account_name}") - account_id = "1234567890" - return f"Account id is {account_id}" + def get_account_credentials(**kwargs) -> str: + """Useful for looking up the account ID.""" + logger.info("Searching the account ID.") + return "Account ID - 1234567890" @function_tool - def get_account_name() -> str: - """Useful for looking up an account name.""" - print("Looking up account for account name") - account_name = "john123" - return f"Account name is {account_name}" + def get_account_balance(account_id: str, **kwargs) -> str: + """Useful for looking up an account balance. Account ID is required.""" - @function_tool - def get_account_balance(account_id: str) -> str: - """Useful for looking up an account balance.""" logger.info(f"Looking up account balance for {account_id}") - context["current_task"] = "" context["task_result"] = "Account has a balance of $1000" return f"Account {account_id} has a balance of $1000" - @function_tool - def is_authenticated() -> bool: - """Checks if the user has a session token.""" - logger.info("Account balance agent is checking if authenticated") - return "User is authentificated" account_balance_agent = ReActAgent( name="account_balance_agent", role=""" You are a helpful assistant that is looking up account balances. - The user may not know the account ID of the account they're interested in, - so you can help them look it up by the name of the account. - If they're trying to transfer money, they have to check their account balance\ first, which you can help with. """, goal="Provide actions according to role", # noqa: E501 llm=llm, - tools=[get_account_id(), get_account_balance(), is_authenticated(), get_account_name()], + tools=[get_account_credentials(), get_account_balance()], ) result = account_balance_agent.run( @@ -184,57 +124,49 @@ def is_authenticated() -> bool: return {"result": result.output.get("content"), **context} # Transfer Money Agent - def transfer_money(context: dict[str, Any]): + def transfer_money(context: dict[str, Any], **kwargs): @function_tool - def transfer_money(from_account_id: str, to_account_id: str, amount: int) -> None: + def transfer_money(from_account_id: str, to_account_id: str, amount: int, **kwargs) -> None: """Useful for transferring money between accounts.""" + logger.info(f"Transferring {amount} from {from_account_id} to account {to_account_id}") context["current_task"] = "" - context["task_result"] = "Money was successfully transferred" + context["task_result"] = f"{amount}$ was successfully transferred" return f"Transferred {amount} to account {to_account_id}" @function_tool - def balance_sufficient(account_id: str, amount: int) -> bool: - """Useful for checking if an account has enough money to transfer.""" - logger.info("Checking if balance is sufficient") - return "There is enough money."
- - @function_tool - def has_balance() -> bool: - """Useful for checking if an account has a balance.""" - logger.info("Checking if account has a balance") - return "Account has enough balance" + def check_balance(account_id: str, amount: int, **kwargs) -> str: + """Useful for checking if an account has enough money. + Account ID and amount of money to transfer are required.""" - @function_tool - def is_authenticated() -> bool: - """Checks if the user has a session token.""" - logger.info("Transfer money agent is checking if authenticated") - return "User has a session token." + logger.info(f"Checking if there is more than {amount} in the account with ID {account_id}") + return "Balance has sufficient amount of money." transfer_money_agent = ReActAgent( name="transfer_money_agent", role=""" You are a helpful assistant that transfers money between accounts. - The user can only do this if they are authenticated, which you can check with the is_authenticated tool. - If they aren't authenticated, tell them to authenticate first. - The user must also have looked up their account balance already,\ - which you can check with the has_balance tool. """, goal="Provide actions according to role", # noqa: E501 llm=llm, - tools=[transfer_money(), balance_sufficient(), has_balance(), is_authenticated()], + tools=[transfer_money(), check_balance()], ) + + sender_id = input("Provide the sender's account ID.\n") + receiver_id = input("Provide the receiver's account ID.\n") + amount = input("Provide the amount of money you want to transfer.\n") + result = transfer_money_agent.run( - input_data={"input": "Transfer 100$ from account 71829301827 to 81092837881."}, + input_data={"input": f"Transfer {amount}$ from account {sender_id} to {receiver_id}."}, ) return {"result": result.output.get("content"), **context} human_feedback_tool = HumanFeedbackTool() - def concierge(context: dict[str, Any]): + def concierge(context: dict[str, Any], **kwargs): if current_task := context.get("current_task"): return {"result": f"Proceed with task {current_task}"} @@ -248,8 +180,7 @@ def concierge(context: dict[str, Any]): "* looking up a stock price\n" "* authenticating the user\n" "* checking an account balance (requires authentication first)\n" - "* transferring money between accounts (requires authentication\n" - "and checking an account balance first)\n" + "* transferring money between accounts (requires authentication)\n" ) result = human_feedback_tool.run( @@ -267,16 +198,14 @@ def concierge(context: dict[str, Any]): graph_orchestrator = GraphOrchestrator(manager=agent_manager, final_summarizer=True, initial_state="concierge") # Orchestration path function - def orchestration(context: dict[str, Any]): + def orchestrate(context: dict[str, Any], **kwargs): formatted_prompt = f""" - You are on orchestration agent. - Your job is to decide which state to run based on the current state of the user - and what they've asked to do. - You run an next state by calling the appropriate state name. - You do not need to call more than one state. + You are an orchestration agent. + Your task is to decide and trigger the next state based on the user's current state and request. + Only one state should be called.
- Just return name of the state you want to do next: + Just return the name of the state: * stock_lookup - To find stock price * transfer_money - To transfer money * authenticate - To authenticate * account_balance - To check account balance @@ -308,12 +237,12 @@ graph_orchestrator.add_state_by_tasks("concierge", [concierge]) graph_orchestrator.add_state_by_tasks("stock_lookup", [stock_lookup]) graph_orchestrator.add_state_by_tasks("transfer_money", [transfer_money]) - graph_orchestrator.add_state_by_tasks("authenticate", [authentificate]) + graph_orchestrator.add_state_by_tasks("authenticate", [authenticate]) graph_orchestrator.add_state_by_tasks("account_balance", [account_balance]) # Add path to other nodes through concierge graph_orchestrator.add_conditional_edge( - "concierge", ["stock_lookup", "transfer_money", "authenticate", "account_balance", END], orchestration + "concierge", ["stock_lookup", "transfer_money", "authenticate", "account_balance", END], orchestrate ) # Add path back from specialized states to concierge diff --git a/examples/human_in_the_loop/confirmation_email_writer/console/console_agent.py b/examples/human_in_the_loop/confirmation_email_writer/console/console_agent.py new file mode 100644 index 00000000..80aefea2 --- /dev/null +++ b/examples/human_in_the_loop/confirmation_email_writer/console/console_agent.py @@ -0,0 +1,65 @@ +from dynamiq.nodes.agents import ReActAgent +from dynamiq.nodes.tools.human_feedback import HumanFeedbackTool, MessageSenderTool +from dynamiq.nodes.tools.python import Python +from dynamiq.types.feedback import ApprovalConfig, FeedbackMethod +from examples.llm_setup import setup_llm + +PYTHON_TOOL_CODE = """ +def run(inputs): + return {"content": "Email was sent."} +""" + + +def run_agent(query) -> dict: + """ + Creates and runs the agent. + + Returns: + dict: Agent final output. + """ + llm = setup_llm() + + email_sender_tool = Python( + name="EmailSenderTool", + description="Sends an email. Put the entire email in a string under the 'email' key.", + code=PYTHON_TOOL_CODE, + approval=ApprovalConfig( + enabled=True, + feedback_method=FeedbackMethod.CONSOLE, + msg_template=( + "Email sketch: {{input_data.email}}.\n" + "Approve or cancel email sending. Send nothing for approval; " + "provide feedback to cancel and regenerate." + ), + ), + ) + + human_feedback_tool = HumanFeedbackTool( + description="This tool can be used to request clarifications from the user.", + input_method=FeedbackMethod.CONSOLE, + ) + + message_sender_tool = MessageSenderTool(output_method=FeedbackMethod.CONSOLE) + + agent = ReActAgent( + name="research_agent", + role=( + "You are a helpful assistant that writes and sends emails using EmailSenderTool. " + "You can request clarifications using HumanFeedbackTool and send messages using MessageSenderTool." + ), + llm=llm, + tools=[email_sender_tool, human_feedback_tool, message_sender_tool], + ) + + return agent.run( + input_data={ + "input": f"Write and send an email: {query}. Notify the user about the status of the email using MessageSenderTool."
+ }, + ).output["content"] + + +if __name__ == "__main__": + query = input("Describe details of the email: ") + result = run_agent(query) + print("Result: ") + print(result) diff --git a/examples/human_in_the_loop/confirmation_email_writer/socket/client.py b/examples/human_in_the_loop/confirmation_email_writer/socket/client.py new file mode 100644 index 00000000..3a434971 --- /dev/null +++ b/examples/human_in_the_loop/confirmation_email_writer/socket/client.py @@ -0,0 +1,46 @@ +import asyncio +import json +import logging + +import websockets + +from dynamiq.types.feedback import APPROVAL_EVENT +from dynamiq.types.streaming import StreamingEventMessage +from examples.human_in_the_loop.confirmation_email_writer.socket.socket_agent import HOST, PORT, WF_ID, SocketMessage + +logger = logging.getLogger(__name__) + +WS_URI = f"ws://{HOST}:{PORT}/ws" + + +async def websocket_client(): + async with websockets.connect(WS_URI) as websocket: + input_query = input("Provide email details: ") + wf_run_event = SocketMessage(type="run", content=input_query) + await websocket.send(wf_run_event.to_json()) + + try: + while True: + event_raw = json.loads(await websocket.recv()) + if event_raw["event"] == "approval": + feedback = input(event_raw["data"]["template"]) + feedback = StreamingEventMessage( + entity_id=WF_ID, data={"feedback": feedback}, event=APPROVAL_EVENT + ).to_json() + + else: + feedback = input(event_raw["data"]["prompt"]) + feedback = StreamingEventMessage(entity_id=WF_ID, data={"content": feedback}).to_json() + await websocket.send(SocketMessage(type="message", content=feedback).to_json()) + + if "finish" in event_raw["data"] and event_raw["data"]["finish"]: + break + + except websockets.ConnectionClosed: + logger.error("WebSocket connection closed by the server") + + await websocket.close() + + +if __name__ == "__main__": + asyncio.run(websocket_client()) diff --git a/examples/human_in_the_loop/confirmation_email_writer/socket/socket_agent.py b/examples/human_in_the_loop/confirmation_email_writer/socket/socket_agent.py new file mode 100644 index 00000000..a3233a88 --- /dev/null +++ b/examples/human_in_the_loop/confirmation_email_writer/socket/socket_agent.py @@ -0,0 +1,147 @@ +import asyncio +import logging +from queue import Queue +from typing import Any, Literal + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel + +from dynamiq.callbacks.streaming import AsyncStreamingIteratorCallbackHandler +from dynamiq.nodes.agents.react import ReActAgent +from dynamiq.nodes.tools.human_feedback import HumanFeedbackTool, MessageSenderTool +from dynamiq.nodes.tools.python import Python +from dynamiq.runnables import RunnableConfig +from dynamiq.types.feedback import ApprovalConfig, FeedbackMethod +from dynamiq.types.streaming import StreamingConfig +from dynamiq.utils.logger import logger +from examples.llm_setup import setup_llm + +HOST = "127.0.0.1" +PORT = 6001 + +WF_ID = "dd643e12-fe89-4eef-b48c-050f01c74517" +app = FastAPI() + +logging.basicConfig(level=logging.INFO) + + +class SocketMessage(BaseModel): + type: Literal["run", "message"] + content: Any + + def to_json(self, **kwargs) -> str: + """Convert to JSON string. + + Returns: + str: JSON string representation. 
+ """ + return self.model_dump_json(**kwargs) + + +PYTHON_TOOL_CODE = """ +def run(inputs): + return {"content": "Email was sent."} +""" + + +def run_agent(request: str, input_queue: Queue, send_handler: AsyncStreamingIteratorCallbackHandler) -> dict: + """ + Creates agent + + Returns: + dict: Agent final output. + """ + llm = setup_llm() + + email_sender_tool = Python( + name="EmailSenderTool", + description="Sends email. Put all email in string under 'email' key. ", + code=PYTHON_TOOL_CODE, + approval=ApprovalConfig( + enabled=True, + feedback_method=FeedbackMethod.STREAM, + msg_template=( + "Email sketch: {{input_data.email}}.\n" + "Approve or cancel email sending. Send nothing for approval;" + "provide feedback to cancel and regenerate." + ), + ), + streaming=StreamingConfig(enabled=True, input_queue=input_queue), + ) + + human_feedback_tool = HumanFeedbackTool( + description="This tool can be used to request some clarifications from user.", + input_method=FeedbackMethod.STREAM, + streaming=StreamingConfig(enabled=True, input_queue=input_queue), + ) + + message_sender_tool = MessageSenderTool( + output_method=FeedbackMethod.STREAM, streaming=StreamingConfig(enabled=True, input_queue=input_queue) + ) + + agent = ReActAgent( + name="research_agent", + role=( + "You are a helpful assistant that has access to the internet using Tavily Tool." + "You can request some clarifications using HumanFeedbackTool and send messages using MessageSenderTool " + ), + llm=llm, + tools=[email_sender_tool, human_feedback_tool, message_sender_tool], + ) + + return agent.run( + input_data={ + "input": f"Write and send email: {request}. Notify user about status of email using MessageSenderTool." + }, + config=RunnableConfig(callbacks=[send_handler]), + ).output["content"] + + +async def _send_stream_events_by_ws(websocket: WebSocket, send_handler: Any): + try: + async for event in send_handler: + await websocket.send_text(event.to_json()) + logger.info("All streaming events sent") + except WebSocketDisconnect as e: + logger.error(f"WebSocket disconnected. Error: {e}") + except asyncio.CancelledError: + logger.error("Task cancelled") + except Exception as e: + logger.error(f"Unexpected error. 
Error: {e}") + finally: + pass + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + logger.info("WebSocket connected") + + message_queue = Queue() + + try: + while True: + message_raw = await websocket.receive_text() + message = SocketMessage.model_validate_json(message_raw) + + if message.type == "run": + send_handler = AsyncStreamingIteratorCallbackHandler() + + asyncio.create_task(_send_stream_events_by_ws(websocket, send_handler)) + await asyncio.sleep(0.01) + + asyncio.get_running_loop().run_in_executor( + None, run_agent, message.content, message_queue, send_handler + ) + + elif message.type == "message": + message_queue.put(message.content) + + except WebSocketDisconnect: + logging.info("WebSocket disconnected") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host=HOST, port=PORT) diff --git a/examples/human_in_the_loop/email_writer_assistant.py b/examples/human_in_the_loop/email_writer_assistant.py deleted file mode 100644 index 7d7af9bb..00000000 --- a/examples/human_in_the_loop/email_writer_assistant.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Any - -from dynamiq.nodes.agents.orchestrators.graph import END, START, GraphOrchestrator -from dynamiq.nodes.agents.orchestrators.graph_manager import GraphAgentManager -from dynamiq.prompts import Message, Prompt -from examples.llm_setup import setup_llm - -llm = setup_llm() - - -def generate_sketch(context: dict[str, Any]): - """Generate draft of email""" - messages = context.get("messages") - - if feedback := context.get("feedback"): - messages.append(Message(role="user", content=f"Generate text again taking into account feedback {feedback}")) - - response = llm.run( - input_data={}, - prompt=Prompt( - messages=messages, - ), - ).output["content"] - - messages.append(Message(role="assistant", content=response)) - - return {"result": response, "messages": messages} - - -def gather_feedback(context: dict[str, Any]): - """Gather feedback about email draft.""" - feedback = input( - f"Email draft:\n" - f"{context['messages'][-1]['content']}\n" - f"Type in SEND to send email, CANCEL to exit, or provide feedback to refine email: \n" - ) - - return {"result": feedback, "feedback": feedback} - - -def router(context: dict[str, Any]): - """Determines next state based on provided feedback.""" - feedback = context.get("feedback") - - if feedback == "SEND": - print("######### Email was sent! #########") - return END - - if feedback == "CANCEL": - print("######### Email was NOT sent! 
#########") - return END - - return "generate_sketch" - - -orchestrator = GraphOrchestrator( - name="Graph orchestrator", - manager=GraphAgentManager(llm=llm), -) - -orchestrator.add_state_by_tasks("generate_sketch", [generate_sketch]) -orchestrator.add_state_by_tasks("gather_feedback", [gather_feedback]) - -orchestrator.add_edge(START, "generate_sketch") -orchestrator.add_edge("generate_sketch", "gather_feedback") - -orchestrator.add_conditional_edge("gather_feedback", ["generate_sketch", END], router) - - -if __name__ == "__main__": - print("Welcome to email writer.") - email_details = input("Provide email details: ") - - orchestrator.context = { - "messages": [Message(role="user", content=email_details)], - } - - orchestrator.run(input_data={}) diff --git a/examples/human_in_the_loop/planning_approval/orchestrator.py b/examples/human_in_the_loop/planning_approval/orchestrator.py new file mode 100644 index 00000000..0ace96fb --- /dev/null +++ b/examples/human_in_the_loop/planning_approval/orchestrator.py @@ -0,0 +1,100 @@ +from dotenv import load_dotenv + +from dynamiq import Workflow +from dynamiq.callbacks import TracingCallbackHandler +from dynamiq.connections import Tavily as TavilyConnection +from dynamiq.flows import Flow +from dynamiq.nodes.agents.orchestrators import LinearOrchestrator +from dynamiq.nodes.agents.orchestrators.linear_manager import LinearAgentManager +from dynamiq.nodes.agents.react import ReActAgent +from dynamiq.nodes.tools import TavilyTool +from dynamiq.runnables import RunnableConfig +from dynamiq.types.feedback import PlanApprovalConfig +from dynamiq.utils.logger import logger +from examples.llm_setup import setup_llm + +# Load environment variables +load_dotenv() + +AGENT_RESEARCHER_ROLE = ( + "An expert in gathering information about a job. " + "The goal is to analyze the company website and the provided description " + "to extract insights on culture, values, and specific needs." +) + +AGENT_WRITER_ROLE = ( + "An expert in creating job descriptions. " + "The goal is to craft a detailed, engaging, and enticing job posting " + "that resonates with the company's values and attracts the right candidates." +) + +AGENT_REVIEWER_ROLE = ( + "An expert in reviewing and editing content. " + "The goal is to ensure the job description is accurate, engaging, " + "and aligned with the company's values and needs." +) + + +def create_workflow() -> Workflow: + """ + Create the workflow with all necessary agents and tools. + + Returns: + Workflow: The configured workflow. + """ + llm = setup_llm() + + search_connection = TavilyConnection() + tool_search = TavilyTool(connection=search_connection) + + agent_researcher = ReActAgent(name="Researcher Analyst", role=AGENT_RESEARCHER_ROLE, llm=llm, tools=[tool_search]) + agent_writer = ReActAgent( + name="Job Description Writer", + role=AGENT_WRITER_ROLE, + llm=llm, + ) + agent_reviewer = ReActAgent( + name="Job Description Reviewer and Editor", + role=AGENT_REVIEWER_ROLE, + llm=llm, + ) + agent_manager = LinearAgentManager(llm=llm) + + linear_orchestrator = LinearOrchestrator( + manager=agent_manager, + agents=[agent_researcher, agent_writer, agent_reviewer], + final_summarizer=True, + plan_approval=PlanApprovalConfig(enabled=True), + ) + + return Workflow( + flow=Flow(nodes=[linear_orchestrator]), + ) + + +def run_planner() -> tuple[str, dict]: + workflow = create_workflow() + + user_prompt = "Analyze the Google's company culture, values, and mission." 
# noqa: E501 + + tracing = TracingCallbackHandler() + try: + result = workflow.run( + input_data={"input": user_prompt}, + config=RunnableConfig(callbacks=[tracing]), + ) + + logger.info("Workflow completed successfully") + + output = result.output[workflow.flow.nodes[0].id]["output"]["content"] + print(output) + + return output, tracing.runs + + except Exception as e: + logger.error(f"An error occurred during workflow execution: {str(e)}") + return "", {} + + +if __name__ == "__main__": + run_planner() diff --git a/examples/human_in_the_loop/streaming_post_writer/client_streaming.html b/examples/human_in_the_loop/streaming_post_writer/client_streaming.html new file mode 100644 index 00000000..9258bfb8 --- /dev/null +++ b/examples/human_in_the_loop/streaming_post_writer/client_streaming.html @@ -0,0 +1,193 @@ + + +
+ + +
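Taken together, this diff introduces two human-in-the-loop checkpoints: node-level execution approval (the new `approval: ApprovalConfig` field on `Node`, checked in `get_approved_data_or_origin`) and plan-level approval in `LinearOrchestrator` (`plan_approval: PlanApprovalConfig` plus `max_plan_retries`). Below is a minimal console-mode sketch of how the two fit together, assuming only the APIs introduced here and the repository's example helper `setup_llm`; the tool name, tool code, and prompt strings are illustrative, not taken from the codebase.

```python
# Minimal sketch of the human-in-the-loop flow added in this diff
# (assumed usage, not a verbatim example from the repository).
from dynamiq.nodes.agents.orchestrators import LinearOrchestrator
from dynamiq.nodes.agents.orchestrators.linear_manager import LinearAgentManager
from dynamiq.nodes.agents.react import ReActAgent
from dynamiq.nodes.tools.python import Python
from dynamiq.types.feedback import ApprovalConfig, FeedbackMethod, PlanApprovalConfig
from examples.llm_setup import setup_llm  # repo example helper; assumed available

llm = setup_llm()

# Node-level approval: the tool does not execute until the user responds.
# An empty reply matches the default accept_pattern ("") and approves; any other
# text cancels the node with a recoverable NodeException carrying the feedback.
report_sender_tool = Python(
    name="ReportSenderTool",  # hypothetical tool for illustration
    description="Sends a report. Put the report text under the 'report' key.",
    code="def run(inputs):\n    return {'content': 'Report was sent.'}",
    approval=ApprovalConfig(
        enabled=True,
        feedback_method=FeedbackMethod.CONSOLE,
        msg_template=(
            "Report sketch: {{input_data.report}}.\n"
            "Send nothing to approve; provide feedback to cancel and regenerate."
        ),
    ),
)

# Plan-level approval: the manager's task plan is shown before any agent runs.
# Rejection feedback is fed back into replanning, up to max_plan_retries times;
# if no plan is ever approved, get_tasks raises OrchestratorError.
orchestrator = LinearOrchestrator(
    manager=LinearAgentManager(llm=llm),
    agents=[ReActAgent(name="writer", role="Drafts the report.", llm=llm, tools=[report_sender_tool])],
    plan_approval=PlanApprovalConfig(enabled=True),
    max_plan_retries=5,
)

# Orchestrators are nodes, so they can be run directly as well as inside a Workflow.
result = orchestrator.run(input_data={"input": "Draft and send a weekly status report."})
print(result.output["content"])
```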