Skip to content

Commit

Permalink
Merge pull request #3 from rayrayraykk/state
Browse files Browse the repository at this point in the history
Add state agent for app
  • Loading branch information
ZiTao-Li authored Jan 16, 2024
2 parents 97be758 + 5d4dbb8 commit 38113e6
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 9 deletions.
3 changes: 2 additions & 1 deletion examples/game/customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def reply(self, x: dict = None) -> Union[dict, tuple]:
# TODO:
# not sure if it is some implicit requirement of the tongyi chat api,
# the first/last message must have role 'user'.
x["role"] = "user"
if x is not None:
x["role"] = "user"

if self.stage == CustomerConv.WARMING_UP and "推荐" in x["content"]:
self.stage = CustomerConv.AFTER_MEAL_CHAT
Expand Down
14 changes: 6 additions & 8 deletions examples/game/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# -*- coding: utf-8 -*-
import os
import yaml
import sys
import inquirer
import random
import argparse
from loguru import logger

import rich.pretty

from agentscope.models import read_model_configs, load_model_by_name
import agentscope
from agentscope.models import read_model_configs
from agentscope.message import Msg
from agentscope.msghub import msghub
from customer import Customer
Expand Down Expand Up @@ -44,6 +41,8 @@ def invited_group_chat(invited_customer, player, cur_plot):
answer = inquirer.prompt(questions)["ans"]
if answer == "是":
msg = player(annoucement)
elif answer == "否":
msg = None
elif answer == "结束邀请对话":
break
for c in invited_customer:
Expand Down Expand Up @@ -235,14 +234,13 @@ def main(args):
args = parser.parse_args()
GAME_CONFIG = yaml.safe_load(open("./config/game_config.yaml"))

logger.add(sys.stderr, level="INFO")

TONGYI_CONFIG = {
"type": "tongyi",
"name": "tongyi_model",
"model_name": "qwen-max-1201",
"api_key": os.environ.get("TONGYI_API_KEY"),
}

read_model_configs(TONGYI_CONFIG)
agentscope.init(model_configs=[TONGYI_CONFIG], logger_level="INFO")

main(args)
2 changes: 2 additions & 0 deletions src/agentscope/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .rpc_agent import RpcAgentBase
from .dialog_agent import DialogAgent
from .dict_dialog_agent import DictDialogAgent
from .state_agent import StateAgent
from .user_agent import UserAgent

# todo: convert Operator to a common base class for AgentBase and PipelineBase
Expand All @@ -16,5 +17,6 @@
"RpcAgentBase",
"DialogAgent",
"DictDialogAgent",
"StateAgent",
"UserAgent",
]
101 changes: 101 additions & 0 deletions src/agentscope/agents/state_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
""" State agent module. """
from typing import Any, Callable, Dict, Union
from loguru import logger

from .agent import AgentBase


class StateAgent(AgentBase):
"""
Manages the state of an agent, allowing for actions to be executed
based on the current state.
Methods:
reply(self, x: dict = None) -> dict: Processes the input based on
the current state handler.
register_state(self, state: str, handler: Callable, properties:
Dict[str, Any] = None): Registers a new state handler.
transition(self, new_state: str): Transitions the agent to a new state.
"""

def __init__(self, *arg: Any, **kwargs: Any):
super().__init__(*arg, **kwargs)
self.cur_state = None
self.state_handlers = {}
self.state_properties = {}

def reply(self, x: dict = None) -> dict:
"""
Define the actions taken by this agent. Handler the input based
on the current state handler and returns the response message.
Args:
x (`dict`, defaults to `None`):
Dialog history and some environment information
Returns:
The agent's response to the input.
"""
handler = self.state_handlers.get(self.cur_state)
if handler is None:
raise ValueError(
f"No handler registered for state '{self.cur_state}'",
)
msg = handler(x)
return msg

def register_state(
self,
state: Union[int, str, float, tuple],
handler: Callable,
properties: Dict[str, Any] = None,
) -> None:
"""
Registers a new state, its handler function, and optionally
properties associated with the state.
Args:
state (Union[int, str, float, tuple]): The name of the state to
register.
handler (Callable): The function that handles the state.
properties (dict, optional): A dictionary of properties related
to the state.
Returns:
None
"""
if state in self.state_handlers:
logger.warning(
f"State '{state}' is already registered. Overwriting.",
)
self.state_handlers[state] = handler
if properties:
self.state_properties[state] = properties

def transition(self, new_state: Union[int, str, float, tuple]) -> None:
"""
Transitions the agent to a new state and updates any associated
properties.
Args:
new_state (Union[int, str, float, tuple]): The state to which
the agent should transition.
Returns:
None
Raises:
ValueError: If the new_state is not registered.
"""
if new_state not in self.state_handlers:
raise ValueError(f"State '{new_state}' is not registered.")
self.cur_state = new_state
# Switch other properties related to the new state
if new_state in self.state_properties:
for prop, value in self.state_properties[new_state].items():
setattr(self, prop, value)
else:
logger.info(
f"No additional properties to switch for state '{new_state}'.",
)

0 comments on commit 38113e6

Please sign in to comment.