From 9c2cd1bb3babd02bcdae3df6f0e306666b080f0a Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 5 Sep 2024 14:35:36 +0800 Subject: [PATCH] bug fix after merge --- src/agentscope/cpp_server/worker.cc | 102 +++++++++++++++++++-------- src/agentscope/cpp_server/worker.h | 3 + src/agentscope/rpc/worker_args.proto | 7 +- 3 files changed, 79 insertions(+), 33 deletions(-) diff --git a/src/agentscope/cpp_server/worker.cc b/src/agentscope/cpp_server/worker.cc index 09167477b..a191982af 100644 --- a/src/agentscope/cpp_server/worker.cc +++ b/src/agentscope/cpp_server/worker.cc @@ -19,12 +19,12 @@ using std::to_string; using namespace pybind11::literals; using WorkerArgs::AgentArgs; -using WorkerArgs::AgentMemoryReturn; using WorkerArgs::CreateAgentArgs; using WorkerArgs::ModelConfigsArgs; using WorkerArgs::ObserveArgs; using WorkerArgs::ReplyArgs; -using WorkerArgs::ReplyReturn; +using WorkerArgs::MsgReturn; +using WorkerArgs::AgentListReturn; Task::Task(const int task_id) : _task_id(task_id), @@ -85,6 +85,9 @@ Worker::Worker( _max_tasks(std::max(max_tasks, 1u)), _max_timeout_seconds(std::max(max_timeout_seconds, 1u)) { + py::object serialize_lib = py::module::import("agentscope.serialize"); + _serialize = serialize_lib.attr("serialize"); + _deserialize = serialize_lib.attr("deserialize"); py::gil_scoped_release release; struct stat info; if (stat("./logs/", &info) != 0) @@ -526,7 +529,7 @@ pair Worker::get_task_result(const int task_id) { string result_str = _tasks[idx].second->get_result(); logger("get_task_result 3: task_id = " + to_string(task_id) + " idx = " + to_string(idx) + " result_str = [" + result_str + "]"); - ReplyReturn result; + MsgReturn result; result.ParseFromString(result_str); logger("get_task_result 4: task_id = " + to_string(task_id) + " idx = " + to_string(idx) + " result_ok = " + to_string(result.ok()) + " result_str = [" + result_str + "]"); return make_pair(result.ok(), result.message()); @@ -730,16 +733,27 @@ string Worker::call_get_agent_list() call_id_list.push_back(call_id); } } - string final_result = "["; + // string final_result = "["; + vector result_list; for (auto call_id : call_id_list) { - string result = get_result(call_id); - logger("call_get_agent_list 1: call_id = " + to_string(call_id) + " result = [" + result + "]"); - if (final_result != "[" && !result.empty()) - final_result += ","; - final_result += result; + string result_str = get_result(call_id); + AgentListReturn result; + result.ParseFromString(result_str); + for (const auto &agent_str : result.agent_str_list()) + { + result_list.push_back(agent_str); + } + // logger("call_get_agent_list 1: call_id = " + to_string(call_id) + " result = [" + result + "]"); + // if (final_result != "[" && !result.empty()) + // final_result += ","; + // final_result += result; } - final_result += "]"; + // final_result += "]"; + py::gil_scoped_acquire acquire; + logger("call_get_agent_list 1: result_list.size() = [" + to_string(result_list.size()) + "]"); + // py::object serialize_lib = py::module::import("agentscope.serialize"); + string final_result = _serialize(result_list).cast(); logger("call_get_agent_list 2: result = [" + final_result + "]"); return final_result; } @@ -747,16 +761,21 @@ string Worker::call_get_agent_list() void Worker::get_agent_list_worker(const int call_id) { py::gil_scoped_acquire acquire; - vector agent_str_list; + // vector agent_str_list; + AgentListReturn result; { shared_lock lock(_agent_pool_mutex); for (auto &iter : _agent_pool) { - agent_str_list.push_back(iter.second.attr("__str__")().cast()); + // agent_str_list.push_back(iter.second.attr("__str__")().cast()); + result.add_agent_str_list(iter.second.attr("__str__")().cast()); } } - string result = py::module::import("json").attr("dumps")(agent_str_list).cast(); - set_result(call_id, result.substr(1, result.size() - 2)); + // string result = py::module::import("json").attr("dumps")(agent_str_list).cast(); + // py::object serialize_lib = py::module::import("agentscope.serialize"); + // string result = serialize_lib.attr("serialize")(agent_str_list).cast(); + // set_result(call_id, result.substr(1, result.size() - 2)); + set_result(call_id, result.SerializeAsString()); } string Worker::call_set_model_configs(const string &model_configs) @@ -800,8 +819,10 @@ pair Worker::call_get_agent_memory(const string &agent_id) AgentArgs args; args.set_agent_id(agent_id); int call_id = call_worker_func(worker_id, function_ids::get_agent_memory, &args); - string result = get_result(call_id); - return make_pair(result[0] == 'T', result.substr(1, result.size() - 1)); + string result_str = get_result(call_id); + MsgReturn result; + result.ParseFromString(result_str); + return make_pair(result.ok(), result.message()); } void Worker::get_agent_memory_worker(const int call_id) @@ -814,15 +835,23 @@ void Worker::get_agent_memory_worker(const int call_id) shared_lock lock(_agent_pool_mutex); py::object agent = _agent_pool[agent_id]; py::object memory = agent.attr("memory"); + MsgReturn result; if (memory.is_none()) { - set_result(call_id, "FAgent [" + agent_id + "] has no memory."); + // set_result(call_id, "FAgent [" + agent_id + "] has no memory."); + result.set_ok(false); + result.set_message("Agent [" + agent_id + "] has no memory."); } else { py::object memory_info = memory.attr("get_memory")(); - set_result(call_id, "T" + py::module::import("json").attr("dumps")(memory_info).cast()); + // py::object serialize_lib = py::module::import("agentscope.serialize"); + string memory_msg = _serialize(memory_info).cast(); + result.set_ok(true); + result.set_message(memory_msg); + // set_result(call_id, "T" + py::module::import("json").attr("dumps")(memory_info).cast()); } + set_result(call_id, result.SerializeAsString()); } pair Worker::call_reply(const string &agent_id, const string &message) @@ -863,12 +892,10 @@ void Worker::reply_worker(const int call_id) shared_lock lock(_agent_pool_mutex); py::object agent = _agent_pool[agent_id]; py::object message_lib = py::module::import("agentscope.message"); - py::object py_message = message.size() ? message_lib.attr("deserialize")(message) : py::none(); + // py::object serialize_lib = py::module::import("agentscope.serialize"); + py::object py_message = message.size() ? _deserialize(message) : py::none(); - py::object msg_class = message_lib.attr("Msg"); - py::object msg = msg_class( - "name"_a = agent.attr("name"), "content"_a = py::none(), "task_id"_a = task_id); - string msg_str = msg.attr("serialize")().cast(); + string msg_str = to_string(task_id); logger("reply_worker 3: call_id = " + to_string(call_id) + " agent_id = " + agent_id + " task_id = " + to_string(task_id) + " callback_id = " + to_string(callback_id) + " msg_str = " + msg_str); set_result(call_id, msg_str); @@ -878,12 +905,13 @@ void Worker::reply_worker(const int call_id) { py_message.attr("update_value")(); } - ReplyReturn result; + MsgReturn result; try { logger("reply_worker 3.1: call_id = " + to_string(call_id) + " agent_id = " + agent_id + " task_id = " + to_string(task_id) + " callback_id = " + to_string(callback_id) + " call reply"); result.set_ok(true); - result.set_message(agent.attr("reply")(py_message).attr("serialize")().cast()); + py::object reply_msg = agent.attr("reply")(py_message); + result.set_message(_serialize(reply_msg).cast()); } catch (const std::exception &e) { @@ -907,6 +935,7 @@ pair Worker::call_observe(const string &agent_id, const string &me args.set_message(message); int call_id = call_worker_func(worker_id, function_ids::observe, &args); string result = get_result(call_id); + logger("call_observe 2: call_id = " + to_string(call_id) + " result = " + result); return make_pair(true, result); } @@ -922,13 +951,28 @@ void Worker::observe_worker(const int call_id) py::object agent = _agent_pool[agent_id]; py::object message_lib = py::module::import("agentscope.message"); py::object PlaceholderMessage_class = message_lib.attr("PlaceholderMessage"); + // py::object serialize_lib = py::module::import("agentscope.serialize"); logger("observe_worker 1: call_id = " + to_string(call_id) + " message = " + message); - py::object py_messages = message.size() ? message_lib.attr("deserialize")(message) : py::list(); - for (auto &py_message : py_messages) + py::object py_messages = message.size() ? _deserialize(message) : py::list(); + // if (py::isinstance(py_messages, py::list())) + // { + // py_messages.attr("update_value")(); + // } + if (py::isinstance(py_messages)) + { + for (auto &py_message : py_messages) + { + if (py::isinstance(py_message, PlaceholderMessage_class)) + { + py_message.attr("update_value")(); + } + } + } + else { - if (py::isinstance(py_message, PlaceholderMessage_class)) + if (py::isinstance(py_messages, PlaceholderMessage_class)) { - py_message.attr("update_value")(); + py_messages.attr("update_value")(); } } py::print("observe_worker: py_messages = ", py_messages); diff --git a/src/agentscope/cpp_server/worker.h b/src/agentscope/cpp_server/worker.h index 50983c72d..2985f554c 100644 --- a/src/agentscope/cpp_server/worker.h +++ b/src/agentscope/cpp_server/worker.h @@ -106,6 +106,9 @@ class Worker const unsigned int _max_tasks; const unsigned int _max_timeout_seconds; + // common used functions + py::object _serialize, _deserialize; + enum function_ids { create_agent = 0, diff --git a/src/agentscope/rpc/worker_args.proto b/src/agentscope/rpc/worker_args.proto index 3931f62c3..4d4ae3860 100644 --- a/src/agentscope/rpc/worker_args.proto +++ b/src/agentscope/rpc/worker_args.proto @@ -28,12 +28,11 @@ message ModelConfigsArgs { bytes model_configs = 1; } -message AgentMemoryReturn { - bool ok = 1; - string message = 2; +message AgentListReturn { + repeated string agent_str_list = 1; } -message ReplyReturn { +message MsgReturn { bool ok = 1; string message = 2; }