diff --git a/examples/dynamic/insurance_openai.py b/examples/dynamic/insurance_openai.py index 630f1a4..131a708 100644 --- a/examples/dynamic/insurance_openai.py +++ b/examples/dynamic/insurance_openai.py @@ -148,15 +148,13 @@ def create_initial_node(): ], "functions": [ { - "type": "function", - "function": { - "name": "collect_age", - "description": "Record customer's age", - "parameters": { - "type": "object", - "properties": {"age": {"type": "integer"}}, - "required": ["age"], - }, + "name": "collect_age", + "handler": collect_age, + "description": "Record customer's age", + "parameters": { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], }, } ], @@ -177,17 +175,15 @@ def create_marital_status_node(): ], "functions": [ { - "type": "function", - "function": { - "name": "collect_marital_status", - "description": "Record marital status", - "parameters": { - "type": "object", - "properties": { - "marital_status": {"type": "string", "enum": ["single", "married"]} - }, - "required": ["marital_status"], + "name": "collect_marital_status", + "handler": collect_marital_status, + "description": "Record marital status", + "parameters": { + "type": "object", + "properties": { + "marital_status": {"type": "string", "enum": ["single", "married"]} }, + "required": ["marital_status"], }, } ], @@ -210,21 +206,19 @@ def create_quote_calculation_node(age: int, marital_status: str): ], "functions": [ { - "type": "function", - "function": { - "name": "calculate_quote", - "description": "Calculate initial insurance quote", - "parameters": { - "type": "object", - "properties": { - "age": {"type": "integer"}, - "marital_status": { - "type": "string", - "enum": ["single", "married"], - }, + "name": "calculate_quote", + "handler": calculate_quote, + "description": "Calculate initial insurance quote", + "parameters": { + "type": "object", + "properties": { + "age": {"type": "integer"}, + "marital_status": { + "type": "string", + "enum": ["single", "married"], }, - "required": ["age", "marital_status"], }, + "required": ["age", "marital_status"], }, } ], @@ -250,26 +244,25 @@ def create_quote_results_node(quote: Dict[str, Any]): ], "functions": [ { - "type": "function", - "function": { - "name": "update_coverage", - "description": "Update coverage options", - "parameters": { - "type": "object", - "properties": { - "coverage_amount": {"type": "integer"}, - "deductible": {"type": "integer"}, - }, - "required": ["coverage_amount", "deductible"], + "name": "update_coverage", + "handler": update_coverage, + "description": "Update coverage options", + "parameters": { + "type": "object", + "properties": { + "coverage_amount": {"type": "integer"}, + "deductible": {"type": "integer"}, }, + "required": ["coverage_amount", "deductible"], }, }, { - "type": "function", - "function": { - "name": "end_quote", - "description": "Complete the quote process", - "parameters": {"type": "object", "properties": {}}, + "name": "end_quote", + "handler": end_quote, + "description": "Complete the quote process", + "parameters": { + "type": "object", + "properties": {}, }, }, ], @@ -360,15 +353,6 @@ async def main(): tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en") llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") - # Create function handlers dictionary - function_handlers = { - "collect_age": collect_age, - "collect_marital_status": collect_marital_status, - "calculate_quote": calculate_quote, - "update_coverage": update_coverage, - "end_quote": end_quote, - } - # Create initial context messages = [ { @@ -405,9 +389,6 @@ async def main(): ) flow_manager.state = {} # Initialize state storage - # Register all functions - await flow_manager.register_functions(function_handlers) - @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) diff --git a/src/pipecat_flows/base.py b/src/pipecat_flows/base.py index 7391e61..d3d1e9a 100644 --- a/src/pipecat_flows/base.py +++ b/src/pipecat_flows/base.py @@ -16,7 +16,7 @@ Both StaticFlowManager and DynamicFlowManager build upon this base. """ -from abc import ABC, abstractmethod +from abc import ABC from typing import Any, Callable, List, Optional from loguru import logger @@ -133,7 +133,6 @@ async def _execute_actions( if post_actions: await self.action_manager.execute_actions(post_actions) - @abstractmethod async def _validate_initialization(self) -> None: """Validate that the manager is properly initialized. diff --git a/src/pipecat_flows/dynamic.py b/src/pipecat_flows/dynamic.py index 7083981..79558cc 100644 --- a/src/pipecat_flows/dynamic.py +++ b/src/pipecat_flows/dynamic.py @@ -3,43 +3,18 @@ # # SPDX-License-Identifier: BSD 2-Clause License # -""" -Dynamic Flow Manager for Pipecat Flows - -This module provides the DynamicFlowManager for handling runtime-determined conversation -flows. It's designed for cases where the conversation structure needs to be determined -during runtime based on external data, API calls, or complex business logic. - -Example: - # Define simple function handler - async def collect_age(args: Dict[str, Any]) -> Dict[str, Any]: - age = args["age"] - return {"age": age} - - # Define transition callback - async def handle_transitions(function_name: str, args: Dict[str, Any], flow_manager): - if function_name == "collect_age": - if args["age"] < 25: - await flow_manager.set_node("young_adult", young_adult_config) - else: - await flow_manager.set_node("standard", standard_config) - - # Initialize and use flow manager - flow_manager = DynamicFlowManager(task, llm, tts, transition_callback=handle_transitions) - await flow_manager.register_functions({"collect_age": collect_age}) - await flow_manager.initialize(initial_messages) -""" - -from typing import Any, Awaitable, Callable, Dict, List, Optional + +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set from loguru import logger -from pipecat.frames.frames import LLMMessagesUpdateFrame, LLMSetToolsFrame +from pipecat.frames.frames import ( + LLMMessagesUpdateFrame, + LLMSetToolsFrame, +) +from pipecat.pipeline.task import PipelineTask from .base import BaseFlowManager -from .exceptions import ( - FlowError, - FlowInitializationError, -) +from .exceptions import FlowError, FlowInitializationError def create_handler_wrapper( @@ -52,17 +27,16 @@ def create_handler_wrapper( Returns: Wrapped handler compatible with Pipecat's LLM function calling system - - Example: - # Original simple handler - async def collect_age(args: Dict[str, Any]) -> Dict[str, Any]: - return {"age": args["age"]} - - # Wrapped for Pipecat - wrapped = create_handler_wrapper(collect_age) """ - async def wrapped(function_name, tool_call_id, args, llm, context, result_callback): + async def wrapped( + function_name: str, + tool_call_id: str, + args: Dict[str, Any], + llm: Any, + context: Any, + result_callback: Callable, + ) -> None: logger.debug(f"Handler called for {function_name} with args: {args}") result = await handler(args) await result_callback(result) @@ -72,31 +46,33 @@ async def wrapped(function_name, tool_call_id, args, llm, context, result_callba class DynamicFlowManager(BaseFlowManager): - """Manages dynamically created conversation flows in Pipecat applications. - - The DynamicFlowManager provides a framework for creating conversational AI - applications where the flow of conversation is determined at runtime. This is - particularly useful for complex interactions where the next state depends on - user input, external data, or business logic. - - The flow manager handles: - - Setting up conversation nodes - - Managing LLM context and available functions - - Executing pre/post actions - - Calling transition callback when functions are executed - - The application provides: - - Simple function handlers for processing user input - - Node configurations for conversation states - - Transition callback for determining flow progression - - Business logic for flow decisions + """Manages dynamically created conversation flows. + + Designed for flows where nodes and functions are created during runtime. + Each node specifies its available functions, and the flow manager handles + registration and state transitions. + + Example: + async def handle_transitions( + function_name: str, + args: Dict[str, Any], + flow_manager: "DynamicFlowManager" + ) -> None: + # Query business logic + next_step = await get_next_step(args) + + # Create new node based on results + await flow_manager.set_node( + "next_step", + create_node_for_step(next_step) + ) """ def __init__( self, - task, - llm, - tts=None, + task: PipelineTask, + llm: Any, + tts: Optional[Any] = None, transition_callback: Optional[ Callable[[str, Dict[str, Any], "DynamicFlowManager"], Awaitable[None]] ] = None, @@ -107,14 +83,12 @@ def __init__( task: PipelineTask instance for queueing frames llm: LLM service instance tts: Optional TTS service for voice actions - transition_callback: Async callback for handling state transitions. - Called after function execution with: - - function_name: Name of function called - - arguments: Arguments passed to function - - flow_manager: Reference to this instance + transition_callback: Optional callback for handling transitions """ super().__init__(task, llm, tts) self.transition_callback = transition_callback + self.state: Dict[str, Any] = {} + self.current_functions: Set[str] = set() async def initialize(self, initial_messages: List[dict]) -> None: """Initialize the flow with starting messages. @@ -131,7 +105,7 @@ async def initialize(self, initial_messages: List[dict]) -> None: return try: - # Queue initial context frames + # Set initial context with no tools await self.task.queue_frame(LLMMessagesUpdateFrame(messages=initial_messages)) await self.task.queue_frame(LLMSetToolsFrame(tools=[])) logger.debug("Initialized dynamic flow manager") @@ -139,46 +113,47 @@ async def initialize(self, initial_messages: List[dict]) -> None: self.initialized = False raise FlowInitializationError(f"Failed to initialize flow: {str(e)}") from e - def _validate_node_config(self, node_id: str, node_config: Dict[str, Any]) -> None: - """Validate node configuration structure. - - Args: - node_id: Identifier for the node - node_config: Node configuration to validate - - Raises: - ValueError: If configuration is invalid - """ - if "messages" not in node_config: - raise ValueError(f"Node '{node_id}' missing required 'messages' field") - if "functions" not in node_config: - raise ValueError(f"Node '{node_id}' missing required 'functions' field") - - if not isinstance(node_config["messages"], list): - raise ValueError(f"Node '{node_id}' messages must be a list") - if not isinstance(node_config["functions"], list): - raise ValueError(f"Node '{node_id}' functions must be a list") - async def set_node(self, node_id: str, node_config: Dict[str, Any]) -> None: """Set up a new conversation node. - This method: - 1. Validates the node configuration - 2. Executes any pre-actions - 3. Updates LLM context with new messages and functions - 4. Executes any post-actions - 5. Updates current node state + Handles: + 1. Function registration for the node + 2. Pre-actions execution + 3. LLM context updates + 4. Post-actions execution Args: - node_id: Identifier for the node - node_config: Complete node configuration including: + node_id: Identifier for the new node + node_config: Node configuration including: - messages: List of messages for LLM context - - functions: List of available functions - - pre_actions: Optional actions to execute before context update - - post_actions: Optional actions to execute after context update + - functions: List of function configurations + - pre_actions: Optional actions to execute before transition + - post_actions: Optional actions to execute after transition - Raises: - FlowError: If node setup fails + Example: + await flow_manager.set_node( + "collect_info", + { + "messages": [{ + "role": "system", + "content": "Collect user information" + }], + "functions": [ + { + "name": "save_info", + "handler": save_info_handler, + "description": "Save user information", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + } + } + } + ] + } + ) """ await self._validate_initialization() @@ -189,77 +164,84 @@ async def set_node(self, node_id: str, node_config: Dict[str, Any]) -> None: if pre_actions := node_config.get("pre_actions"): await self._execute_actions(pre_actions=pre_actions) - # Update LLM context - await self._update_llm_context(node_config["messages"], node_config["functions"]) + # Register functions and create tools list + tools = [] + new_functions: Set[str] = set() + + for func_config in node_config["functions"]: + name = func_config["name"] + if name not in self.current_functions: + # Register new function + await self.llm.register_function( + name, create_handler_wrapper(func_config["handler"]) + ) + new_functions.add(name) + + # Create function definition in provider-specific format + function_def = { + "name": name, + "description": func_config["description"], + "parameters": func_config["parameters"], + } + + # Let the adapter format it correctly + tools.extend(self.adapter.format_functions([function_def])) + + # Update LLM context with new messages and tools + await self._update_llm_context(node_config["messages"], tools) # Execute post-actions if any if post_actions := node_config.get("post_actions"): await self._execute_actions(post_actions=post_actions) + # Update state self.current_node = node_id - logger.debug(f"Node set successfully: {node_id}") + self.current_functions = new_functions + + logger.debug(f"Successfully set node: {node_id}") except Exception as e: raise FlowError(f"Failed to set node {node_id}: {str(e)}") from e - async def register_functions( - self, functions: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] - ) -> None: - """Register functions with the LLM service. + async def handle_function_call(self, function_name: str, args: Dict[str, Any]) -> None: + """Handle function calls and transitions. - Functions should be simple async handlers that take a dictionary of arguments - and return a dictionary of results. The framework handles adapting these - simple handlers to Pipecat's needs and managing transitions. + This method: + 1. Executes the function + 2. Calls the transition callback if provided + 3. Updates state based on results Args: - functions: Dictionary mapping function names to handlers. - Handlers should be async functions with the signature: - async def handler(args: Dict[str, Any]) -> Dict[str, Any] + function_name: Name of the called function + args: Arguments passed to the function + """ + try: + # Execute transition callback if provided + if self.transition_callback: + await self.transition_callback(function_name, args, self) - Example: - async def collect_age(args: Dict[str, Any]) -> Dict[str, Any]: - age = args["age"] - return {"age": age} + logger.debug(f"Handled function call: {function_name}") - await flow_manager.register_functions({ - "collect_age": collect_age - }) - """ - for name, handler in functions.items(): - logger.debug(f"Registering function: {name}") - - async def wrapper( - function_name: str, - tool_call_id: str, - arguments: Dict[str, Any], - llm: Any, - context: Any, - result_callback: Callable, - ) -> None: - # First wrap the simple handler to match Pipecat's needs - pipecat_handler = create_handler_wrapper(handler) - - # Call handler and get result - await pipecat_handler( - function_name, tool_call_id, arguments, llm, context, result_callback - ) - - # Call transition callback if provided - if self.transition_callback: - logger.debug(f"Triggering transition for {function_name}") - try: - await self.transition_callback(function_name, arguments, self) - except Exception as e: - logger.error(f"Error in transition callback: {str(e)}") - # Don't re-raise the error - we want to continue execution - - self.llm.register_function(name, wrapper) - logger.debug(f"Registered function: {name}") - - async def _validate_initialization(self) -> None: - """Validate manager is initialized. + except Exception as e: + raise FlowError(f"Error handling function {function_name}: {str(e)}") from e + + def _validate_node_config(self, node_id: str, config: Dict[str, Any]) -> None: + """Validate node configuration structure. + + Args: + node_id: Identifier for the node being validated + config: Node configuration to validate Raises: - FlowError: If manager not initialized + ValueError: If configuration is invalid """ - await super()._validate_initialization() + if "messages" not in config: + raise ValueError(f"Node '{node_id}' missing required 'messages' field") + if "functions" not in config: + raise ValueError(f"Node '{node_id}' missing required 'functions' field") + + for func in config["functions"]: + required = {"name", "handler", "description", "parameters"} + missing = required - set(func.keys()) + if missing: + raise ValueError(f"Function in node '{node_id}' missing required fields: {missing}") diff --git a/src/pipecat_flows/static.py b/src/pipecat_flows/static.py index a8179ee..0cdec12 100644 --- a/src/pipecat_flows/static.py +++ b/src/pipecat_flows/static.py @@ -42,7 +42,6 @@ from .base import BaseFlowManager from .config import FlowConfig from .exceptions import ( - FlowError, FlowInitializationError, FlowTransitionError, InvalidFunctionError, @@ -202,13 +201,3 @@ async def handle_transition(self, function_name: str) -> None: except Exception as e: raise FlowTransitionError(f"Failed to execute transition: {str(e)}") from e - - async def _validate_initialization(self) -> None: - """Validate manager is initialized. - - Raises: - FlowError: If manager not initialized - """ - await super()._validate_initialization() - if not self.flow: - raise FlowError("Flow state not properly initialized")