Skip to content

Commit

Permalink
add state agent
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk committed Jan 15, 2024
1 parent 06929d7 commit 43dc8f1
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/agentscope/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .rpc_agent import RpcAgentBase
from .dialog_agent import DialogAgent
from .dict_dialog_agent import DictDialogAgent
from .state_agent import StateAgent

# todo: convert Operator to a common base class for AgentBase and PipelineBase
_Operator = Callable[..., dict]
Expand All @@ -15,4 +16,5 @@
"RpcAgentBase",
"DialogAgent",
"DictDialogAgent",
"StateAgent",
]
99 changes: 99 additions & 0 deletions src/agentscope/agents/state_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
""" State agent module. """
from typing import Any, Callable, Dict
from loguru import logger

from .agent import AgentBase


class StateAgent(AgentBase):
"""
Manages the state of an agent, allowing for actions to be executed
based on the current state.
Methods:
reply(self, x: dict = None) -> dict: Processes the input based on
the current state handler.
register_state(self, state: str, handler: Callable, properties:
Dict[str, Any] = None): Registers a new state handler.
transition(self, new_state: str): Transitions the agent to a new state.
"""

def __init__(self, *arg: Any, **kwargs: Any):
super().__init__(*arg, **kwargs)
self.cur_state = None
self.state_handlers = {}
self.state_properties = {}

def reply(self, x: dict = None) -> dict:
"""
Define the actions taken by this agent. Handler the input based
on the current state handler and returns the response message.
Args:
x (`dict`, defaults to `None`):
Dialog history and some environment information
Returns:
The agent's response to the input.
"""
handler = self.state_handlers.get(self.cur_state)
if handler is None:
raise ValueError(
f"No handler registered for state '{self.cur_state}'",
)
msg = handler(x)
return msg

def register_state(
self,
state: str,
handler: Callable,
properties: Dict[str, Any] = None,
) -> None:
"""
Registers a new state, its handler function, and optionally
properties associated with the state.
Args:
state (str): The name of the state to register.
handler (Callable): The function that handles the state.
properties (dict, optional): A dictionary of properties related
to the state.
Returns:
None
"""
if state in self.state_handlers:
logger.warning(
f"State '{state}' is already registered. Overwriting.",
)
self.state_handlers[state] = handler
if properties:
self.state_properties[state] = properties

def transition(self, new_state: str) -> None:
"""
Transitions the agent to a new state and updates any associated
properties.
Args:
new_state (str): The state to which the agent should transition.
Returns:
None
Raises:
ValueError: If the new_state is not registered.
"""
if new_state not in self.state_handlers:
raise ValueError(f"State '{new_state}' is not registered.")
self.cur_state = new_state
# Switch other properties related to the new state
if new_state in self.state_properties:
for prop, value in self.state_properties[new_state].items():
setattr(self, prop, value)
else:
logger.info(
f"No additional properties to switch for state '{new_state}'.",
)

0 comments on commit 43dc8f1

Please sign in to comment.