diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index ed64f9d241cc..a85e40815a55 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,11 +2,15 @@ ## Unreleased +### What's new? + - **Support custom** `ClientManager` **in** `start_driver()` ([#2292](https://github.com/adap/flower/pull/2292)) - **Update REST API to support create and delete nodes** ([#2283](https://github.com/adap/flower/pull/2283)) -### What's new? +- **Update the C++ SDK** ([#2537](https://github/com/adap/flower/pull/2537), [#2528](https://github/com/adap/flower/pull/2528), [#2523](https://github.com/adap/flower/pull/2523), [#2522](https://github.com/adap/flower/pull/2522)) + + Add gRPC request-response capability to the C++ SDK. - **Fix the incorrect return types of Strategy** ([#2432](https://github.com/adap/flower/pull/2432/files)) diff --git a/examples/quickstart-cpp/CMakeLists.txt b/examples/quickstart-cpp/CMakeLists.txt index 79af6a0ef17e..d2b23ae38b21 100644 --- a/examples/quickstart-cpp/CMakeLists.txt +++ b/examples/quickstart-cpp/CMakeLists.txt @@ -30,34 +30,115 @@ endif() ###################### ### FLWR_GRPC_PROTO +get_filename_component(FLWR_PROTO_BASE_PATH "../../src/proto/" ABSOLUTE) +get_filename_component(FLWR_TRANS_PROTO "../../src/proto/flwr/proto/transport.proto" ABSOLUTE) +get_filename_component(FLWR_NODE_PROTO "../../src/proto/flwr/proto/node.proto" ABSOLUTE) +get_filename_component(FLWR_TASK_PROTO "../../src/proto/flwr/proto/task.proto" ABSOLUTE) +get_filename_component(FLWR_FLEET_PROTO "../../src/proto/flwr/proto/fleet.proto" ABSOLUTE) + +set(FLWR_TRANS_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.pb.cc") +set(FLWR_TRANS_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.pb.h") +set(FLWR_TRANS_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.grpc.pb.cc") +set(FLWR_TRANS_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.grpc.pb.h") + +set(FLWR_NODE_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.pb.cc") +set(FLWR_NODE_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.pb.h") +set(FLWR_NODE_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.grpc.pb.cc") +set(FLWR_NODE_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.grpc.pb.h") + +set(FLWR_TASK_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.pb.cc") +set(FLWR_TASK_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.pb.h") +set(FLWR_TASK_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.grpc.pb.cc") +set(FLWR_TASK_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.grpc.pb.h") + +set(FLWR_FLEET_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.pb.cc") +set(FLWR_FLEET_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.pb.h") +set(FLWR_FLEET_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.grpc.pb.cc") +set(FLWR_FLEET_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.grpc.pb.h") -get_filename_component(FLWR_PROTO "../../src/proto/flwr/proto/transport.proto" ABSOLUTE) -get_filename_component(FLWR_PROTO_PATH "${FLWR_PROTO}" PATH) +# External building command to generate gRPC source files. +add_custom_command( + OUTPUT "${FLWR_TRANS_PROTO_SRCS}" + "${FLWR_TRANS_PROTO_HDRS}" + "${FLWR_TRANS_GRPC_SRCS}" + "${FLWR_TRANS_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_TRANS_PROTO}" + DEPENDS "${FLWR_TRANS_PROTO}" +) + +add_custom_command( + OUTPUT + "${FLWR_NODE_PROTO_SRCS}" + "${FLWR_NODE_PROTO_HDRS}" + "${FLWR_NODE_GRPC_SRCS}" + "${FLWR_NODE_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_NODE_PROTO}" + DEPENDS + "${FLWR_NODE_PROTO}" +) -set(FLWR_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/transport.pb.cc") -set(FLWR_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/transport.pb.h") -set(FLWR_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/transport.grpc.pb.cc") -set(FLAR_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/transport.grpc.pb.h") +add_custom_command( + OUTPUT + "${FLWR_TASK_PROTO_SRCS}" + "${FLWR_TASK_PROTO_HDRS}" + "${FLWR_TASK_GRPC_SRCS}" + "${FLWR_TASK_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_TASK_PROTO}" + DEPENDS + "${FLWR_TASK_PROTO}" +) -# External building command to generate gRPC source files. add_custom_command( - OUTPUT "${FLWR_PROTO_SRCS}" "${FLWR_PROTO_HDRS}" "${FLWR_GRPC_SRCS}" "${FLWR_GRPC_HDRS}" + OUTPUT + "${FLWR_FLEET_PROTO_SRCS}" + "${FLWR_FLEET_PROTO_HDRS}" + "${FLWR_FLEET_GRPC_SRCS}" + "${FLWR_FLEET_GRPC_HDRS}" COMMAND ${_PROTOBUF_PROTOC} ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" - -I "${FLWR_PROTO_PATH}" + -I "${FLWR_PROTO_BASE_PATH}" --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" - "${FLWR_PROTO}" - DEPENDS "${FLWR_PROTO}" + "${FLWR_FLEET_PROTO}" + DEPENDS + "${FLWR_FLEET_PROTO}" ) -add_library(flwr_grpc_proto - ${FLWR_GRPC_SRCS} - ${FLWR_GRPC_HDRS} - ${FLWR_PROTO_SRCS} - ${FLWR_PROTO_HDRS} +add_library(flwr_grpc_proto STATIC + ${FLWR_TRANS_GRPC_SRCS} + ${FLWR_TRANS_GRPC_HDRS} + ${FLWR_TRANS_PROTO_SRCS} + ${FLWR_TRANS_PROTO_HDRS} + ${FLWR_NODE_GRPC_SRCS} + ${FLWR_NODE_GRPC_HDRS} + ${FLWR_NODE_PROTO_SRCS} + ${FLWR_NODE_PROTO_HDRS} + ${FLWR_TASK_GRPC_SRCS} + ${FLWR_TASK_GRPC_HDRS} + ${FLWR_TASK_PROTO_SRCS} + ${FLWR_TASK_PROTO_HDRS} + ${FLWR_FLEET_GRPC_SRCS} + ${FLWR_FLEET_GRPC_HDRS} + ${FLWR_FLEET_PROTO_SRCS} + ${FLWR_FLEET_PROTO_HDRS} ) + target_include_directories(flwr_grpc_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(flwr_grpc_proto diff --git a/examples/quickstart-cpp/driver.py b/examples/quickstart-cpp/driver.py new file mode 100644 index 000000000000..037623ee77cf --- /dev/null +++ b/examples/quickstart-cpp/driver.py @@ -0,0 +1,10 @@ +import flwr as fl +from fedavg_cpp import FedAvgCpp + +# Start Flower server for three rounds of federated learning +if __name__ == "__main__": + fl.driver.start_driver( + server_address="0.0.0.0:9091", + config=fl.server.ServerConfig(num_rounds=3), + strategy=FedAvgCpp(), + ) diff --git a/examples/quickstart-cpp/src/main.cc b/examples/quickstart-cpp/src/main.cc index fb3c533a3841..f294f9d69473 100644 --- a/examples/quickstart-cpp/src/main.cc +++ b/examples/quickstart-cpp/src/main.cc @@ -2,44 +2,58 @@ #include "start.h" int main(int argc, char **argv) { - if (argc != 3) { - std::cout << "Client takes three arguments as follows: " << std::endl; - std::cout << "./client CLIENT_ID SERVER_URL" << std::endl; - std::cout << "Example: ./flwr_client 0 '127.0.0.1:8080'" << std::endl; - return 0; - } - - // Parsing arguments - const std::string CLIENT_ID = argv[1]; - const std::string SERVER_URL = argv[2]; - - // Populate local datasets - std::vector ms{3.5, 9.3}; // b + m_0*x0 + m_1*x1 - double b = 1.7; - std::cout <<"Training set:" << std::endl; - SyntheticDataset local_training_data = SyntheticDataset(ms, b, 1000); - std::cout << std::endl; - - std::cout <<"Validation set:" << std::endl; - SyntheticDataset local_validation_data = SyntheticDataset(ms, b, 100); - std::cout << std::endl; - - std::cout <<"Test set:" << std::endl; - SyntheticDataset local_test_data = SyntheticDataset(ms, b, 500); - std::cout << std::endl; - - // Define a model - LineFitModel model = LineFitModel(500, 0.01, ms.size()); - - // Initialize TorchClient - SimpleFlwrClient client(CLIENT_ID, model, local_training_data, local_validation_data, local_test_data); - - // Define a server address - std::string server_add = SERVER_URL; - - // Start client + if (argc != 3 && argc != 4) { + std::cout << "Client takes three mandatory arguments and one optional as " + "follows: " + << std::endl; + std::cout << "./client CLIENT_ID SERVER_URL [GRPC_MODE]" << std::endl; + std::cout + << "GRPC_MODE is optional and can be either 'bidi' (default) or 'rere'." + << std::endl; + std::cout << "Example: ./flwr_client 0 '127.0.0.1:8080' bidi" << std::endl; + std::cout << "This is the same as: ./flwr_client 0 '127.0.0.1:8080'" + << std::endl; + return 0; + } + + // Parsing arguments + const std::string CLIENT_ID = argv[1]; + const std::string SERVER_URL = argv[2]; + + // Populate local datasets + std::vector ms{3.5, 9.3}; // b + m_0*x0 + m_1*x1 + double b = 1.7; + std::cout << "Training set:" << std::endl; + SyntheticDataset local_training_data = SyntheticDataset(ms, b, 1000); + std::cout << std::endl; + + std::cout << "Validation set:" << std::endl; + SyntheticDataset local_validation_data = SyntheticDataset(ms, b, 100); + std::cout << std::endl; + + std::cout << "Test set:" << std::endl; + SyntheticDataset local_test_data = SyntheticDataset(ms, b, 500); + std::cout << std::endl; + + // Define a model + LineFitModel model = LineFitModel(500, 0.01, ms.size()); + + // Initialize TorchClient + SimpleFlwrClient client(CLIENT_ID, model, local_training_data, + local_validation_data, local_test_data); + + // Define a server address + std::string server_add = SERVER_URL; + + if (argc == 4 && std::string(argv[3]) == "rere") { + std::cout << "Starting rere client" << std::endl; + // Start rere client + start::start_rere_client(server_add, &client); + } else { + std::cout << "Starting bidi client" << std::endl; + // Start bidi client start::start_client(server_add, &client); + } - return 0; + return 0; } - diff --git a/src/cc/flwr/CMakeLists.txt b/src/cc/flwr/CMakeLists.txt index 8ab7dc4c2964..c3d11d3c0e33 100644 --- a/src/cc/flwr/CMakeLists.txt +++ b/src/cc/flwr/CMakeLists.txt @@ -27,32 +27,112 @@ else() endif() # FLWR_GRPC_PROTO +get_filename_component(FLWR_PROTO_BASE_PATH "../../proto/" ABSOLUTE) +get_filename_component(FLWR_TRANS_PROTO "../../proto/flwr/proto/transport.proto" ABSOLUTE) +get_filename_component(FLWR_NODE_PROTO "../../proto/flwr/proto/node.proto" ABSOLUTE) +get_filename_component(FLWR_TASK_PROTO "../../proto/flwr/proto/task.proto" ABSOLUTE) +get_filename_component(FLWR_FLEET_PROTO "../../proto/flwr/proto/fleet.proto" ABSOLUTE) + +set(FLWR_TRANS_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.pb.cc") +set(FLWR_TRANS_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.pb.h") +set(FLWR_TRANS_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.grpc.pb.cc") +set(FLWR_TRANS_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.grpc.pb.h") + +set(FLWR_NODE_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.pb.cc") +set(FLWR_NODE_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.pb.h") +set(FLWR_NODE_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.grpc.pb.cc") +set(FLWR_NODE_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.grpc.pb.h") + +set(FLWR_TASK_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.pb.cc") +set(FLWR_TASK_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.pb.h") +set(FLWR_TASK_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.grpc.pb.cc") +set(FLWR_TASK_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.grpc.pb.h") + +set(FLWR_FLEET_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.pb.cc") +set(FLWR_FLEET_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.pb.h") +set(FLWR_FLEET_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.grpc.pb.cc") +set(FLWR_FLEET_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.grpc.pb.h") -get_filename_component(FLWR_PROTO "../../proto/flwr/proto/transport.proto" ABSOLUTE) -get_filename_component(FLWR_PROTO_PATH "${FLWR_PROTO}" PATH) +# External building command to generate gRPC source files. +add_custom_command( + OUTPUT "${FLWR_TRANS_PROTO_SRCS}" + "${FLWR_TRANS_PROTO_HDRS}" + "${FLWR_TRANS_GRPC_SRCS}" + "${FLWR_TRANS_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_TRANS_PROTO}" + DEPENDS "${FLWR_TRANS_PROTO}" +) -set(FLWR_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/transport.pb.cc") -set(FLWR_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/transport.pb.h") -set(FLWR_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/transport.grpc.pb.cc") -set(FLAR_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/transport.grpc.pb.h") +add_custom_command( + OUTPUT + "${FLWR_NODE_PROTO_SRCS}" + "${FLWR_NODE_PROTO_HDRS}" + "${FLWR_NODE_GRPC_SRCS}" + "${FLWR_NODE_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_NODE_PROTO}" + DEPENDS + "${FLWR_NODE_PROTO}" +) + +add_custom_command( + OUTPUT + "${FLWR_TASK_PROTO_SRCS}" + "${FLWR_TASK_PROTO_HDRS}" + "${FLWR_TASK_GRPC_SRCS}" + "${FLWR_TASK_GRPC_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${FLWR_PROTO_BASE_PATH}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${FLWR_TASK_PROTO}" + DEPENDS + "${FLWR_TASK_PROTO}" +) -# External building command to generate gRPC source files. add_custom_command( - OUTPUT "${FLWR_PROTO_SRCS}" "${FLWR_PROTO_HDRS}" "${FLWR_GRPC_SRCS}" "${FLWR_GRPC_HDRS}" + OUTPUT + "${FLWR_FLEET_PROTO_SRCS}" + "${FLWR_FLEET_PROTO_HDRS}" + "${FLWR_FLEET_GRPC_SRCS}" + "${FLWR_FLEET_GRPC_HDRS}" COMMAND ${_PROTOBUF_PROTOC} ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" - -I "${FLWR_PROTO_PATH}" + -I "${FLWR_PROTO_BASE_PATH}" --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" - "${FLWR_PROTO}" - DEPENDS "${FLWR_PROTO}" + "${FLWR_FLEET_PROTO}" + DEPENDS + "${FLWR_FLEET_PROTO}" ) add_library(flwr_grpc_proto STATIC - ${FLWR_GRPC_SRCS} - ${FLWR_GRPC_HDRS} - ${FLWR_PROTO_SRCS} - ${FLWR_PROTO_HDRS} + ${FLWR_TRANS_GRPC_SRCS} + ${FLWR_TRANS_GRPC_HDRS} + ${FLWR_TRANS_PROTO_SRCS} + ${FLWR_TRANS_PROTO_HDRS} + ${FLWR_NODE_GRPC_SRCS} + ${FLWR_NODE_GRPC_HDRS} + ${FLWR_NODE_PROTO_SRCS} + ${FLWR_NODE_PROTO_HDRS} + ${FLWR_TASK_GRPC_SRCS} + ${FLWR_TASK_GRPC_HDRS} + ${FLWR_TASK_PROTO_SRCS} + ${FLWR_TASK_PROTO_HDRS} + ${FLWR_FLEET_GRPC_SRCS} + ${FLWR_FLEET_GRPC_HDRS} + ${FLWR_FLEET_PROTO_SRCS} + ${FLWR_FLEET_PROTO_HDRS} ) target_include_directories(flwr_grpc_proto @@ -97,8 +177,14 @@ install(TARGETS flwr_merged EXPORT flwrTargets ) install( FILES - ${CMAKE_CURRENT_BINARY_DIR}/transport.grpc.pb.h - ${CMAKE_CURRENT_BINARY_DIR}/transport.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.grpc.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/fleet.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.grpc.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/node.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.grpc.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/task.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.grpc.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/flwr/proto/transport.pb.h DESTINATION include ) install(DIRECTORY include/ DESTINATION include) diff --git a/src/cc/flwr/include/grpc_rere.h b/src/cc/flwr/include/grpc_rere.h new file mode 100644 index 000000000000..eabd05cbc2c9 --- /dev/null +++ b/src/cc/flwr/include/grpc_rere.h @@ -0,0 +1,29 @@ +/************************************************************************************************* + * + * @file start.h + * + * @brief Create a gRPC channel to connect to the server and enable message + *communication + * + * @author Lekang Jiang + * + * @version 1.0 + * + * @date 06/09/2021 + * + *************************************************************************************************/ + +#ifndef GRPC_RERE_H +#define GRPC_RERE_H +#pragma once +#include "message_handler.h" +#include + +void create_node(const std::unique_ptr &stub); +void delete_node(const std::unique_ptr &stub); +void send(const std::unique_ptr &stub, + flwr::proto::TaskRes task_res); +std::optional +receive(const std::unique_ptr &stub); + +#endif diff --git a/src/cc/flwr/include/message_handler.h b/src/cc/flwr/include/message_handler.h index 2c2229d2f972..7f008f38912e 100644 --- a/src/cc/flwr/include/message_handler.h +++ b/src/cc/flwr/include/message_handler.h @@ -37,3 +37,10 @@ ClientMessage _evaluate(flwr_local::Client *client, std::tuple handle(flwr_local::Client *client, ServerMessage server_msg); + +std::tuple +handle_task(flwr_local::Client *client, const flwr::proto::TaskIns &task_ins); + +flwr::proto::TaskRes configure_task_res(const flwr::proto::TaskRes &task_res, + const flwr::proto::TaskIns &task_ins, + const flwr::proto::Node &node); diff --git a/src/cc/flwr/include/serde.h b/src/cc/flwr/include/serde.h index 314668ac650b..aceef69b6e95 100644 --- a/src/cc/flwr/include/serde.h +++ b/src/cc/flwr/include/serde.h @@ -14,9 +14,13 @@ #pragma once // cppcheck-suppress missingInclude -#include "transport.grpc.pb.h" +#include "flwr/proto/transport.grpc.pb.h" // cppcheck-suppress missingInclude -#include "transport.pb.h" +#include "flwr/proto/transport.pb.h" +// cppcheck-suppress missingInclude +#include "flwr/proto/fleet.grpc.pb.h" +// cppcheck-suppress missingInclude +#include "flwr/proto/fleet.pb.h" #include "typing.h" using flwr::proto::ClientMessage; using flwr::proto::ServerMessage; diff --git a/src/cc/flwr/include/start.h b/src/cc/flwr/include/start.h index a99dd8520dd7..2c233be8249c 100644 --- a/src/cc/flwr/include/start.h +++ b/src/cc/flwr/include/start.h @@ -17,15 +17,10 @@ #define START_H #pragma once #include "client.h" +#include "grpc_rere.h" #include "message_handler.h" #include -using flwr::proto::ClientMessage; -using flwr::proto::FlowerService; -using flwr::proto::ServerMessage; -using grpc::Channel; -using grpc::ClientContext; -using grpc::ClientReaderWriter; -using grpc::Status; +#include #define GRPC_MAX_MESSAGE_LENGTH 536870912 // == 512 * 1024 * 1024 @@ -56,5 +51,8 @@ class start { static void start_client(std::string server_address, flwr_local::Client *client, int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH); + static void + start_rere_client(std::string server_address, flwr_local::Client *client, + int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH); }; #endif diff --git a/src/cc/flwr/src/grpc_rere.cc b/src/cc/flwr/src/grpc_rere.cc new file mode 100644 index 000000000000..d9d61938b9f4 --- /dev/null +++ b/src/cc/flwr/src/grpc_rere.cc @@ -0,0 +1,159 @@ +#include "grpc_rere.h" + +const std::string KEY_NODE = "node"; +const std::string KEY_TASK_INS = "current_task_ins"; + +std::map> node_store; +std::map> state; + +std::mutex node_store_mutex; +std::mutex state_mutex; + +std::optional get_node_from_store() { + std::lock_guard lock(node_store_mutex); + auto node = node_store.find(KEY_NODE); + if (node == node_store.end() || !node->second.has_value()) { + std::cerr << "Node instance missing" << std::endl; + return std::nullopt; + } + return node->second; +} + +std::optional get_current_task_ins() { + std::lock_guard state_lock(state_mutex); + auto current_task_ins = state.find(KEY_TASK_INS); + if (current_task_ins == state.end() || + !current_task_ins->second.has_value()) { + std::cerr << "No current TaskIns" << std::endl; + return std::nullopt; + } + return current_task_ins->second; +} + +void create_node(const std::unique_ptr &stub) { + flwr::proto::CreateNodeRequest create_node_request; + flwr::proto::CreateNodeResponse create_node_response; + + grpc::ClientContext context; + grpc::Status status = + stub->CreateNode(&context, create_node_request, &create_node_response); + + if (!status.ok()) { + std::cerr << "CreateNode RPC failed: " << status.error_message() + << std::endl; + return; + } + + // Validate the response + if (!create_node_response.has_node()) { + std::cerr << "Received response does not contain a node." << std::endl; + return; + } + + { + std::lock_guard lock(node_store_mutex); + node_store[KEY_NODE] = create_node_response.node(); + } +} + +void delete_node(const std::unique_ptr &stub) { + auto node = get_node_from_store(); + if (!node) { + return; + } + flwr::proto::DeleteNodeRequest delete_node_request; + flwr::proto::DeleteNodeResponse delete_node_response; + + auto heap_node = new flwr::proto::Node(*node); + delete_node_request.set_allocated_node(heap_node); + + grpc::ClientContext context; + grpc::Status status = + stub->DeleteNode(&context, delete_node_request, &delete_node_response); + + if (!status.ok()) { + std::cerr << "DeleteNode RPC failed with status: " << status.error_message() + << std::endl; + delete heap_node; // Make sure to delete if status is not ok + return; + } else { + delete_node_request.release_node(); // Release if status is ok + } + + // TODO: Check if Node needs to be removed from local map + // node_store.erase(node); +} + +std::optional +receive(const std::unique_ptr &stub) { + auto node = get_node_from_store(); + if (!node) { + return std::nullopt; + } + flwr::proto::PullTaskInsResponse response; + flwr::proto::PullTaskInsRequest request; + + request.set_allocated_node(new flwr::proto::Node(*node)); + + grpc::ClientContext context; + grpc::Status status = stub->PullTaskIns(&context, request, &response); + + // Release ownership so that the heap_node won't be deleted when request + // goes out of scope. + request.release_node(); + + if (!status.ok()) { + std::cerr << "PullTaskIns RPC failed with status: " + << status.error_message() << std::endl; + return std::nullopt; + } + + if (response.task_ins_list_size() > 0) { + flwr::proto::TaskIns task_ins = response.task_ins_list().at(0); + // TODO: Validate TaskIns + + { + std::lock_guard state_lock(state_mutex); + state[KEY_TASK_INS] = task_ins; + } + + return task_ins; + } else { + std::cerr << "TaskIns list is empty." << std::endl; + return std::nullopt; + } +} + +void send(const std::unique_ptr &stub, + flwr::proto::TaskRes task_res) { + auto node = get_node_from_store(); + if (!node) { + return; + } + + auto task_ins = get_current_task_ins(); + if (!task_ins) { + return; + } + + // TODO: Validate TaskIns + + flwr::proto::TaskRes new_task_res = + configure_task_res(task_res, *task_ins, *node); + + flwr::proto::PushTaskResRequest request; + *request.add_task_res_list() = new_task_res; + flwr::proto::PushTaskResResponse response; + + grpc::ClientContext context; + grpc::Status status = stub->PushTaskRes(&context, request, &response); + + if (!status.ok()) { + std::cerr << "PushTaskRes RPC failed with status: " + << status.error_message() << std::endl; + return; + } else { + std::lock_guard state_lock(state_mutex); + state[KEY_TASK_INS].reset(); + } +} diff --git a/src/cc/flwr/src/message_handler.cc b/src/cc/flwr/src/message_handler.cc index b985548f6cea..772910e351df 100644 --- a/src/cc/flwr/src/message_handler.cc +++ b/src/cc/flwr/src/message_handler.cc @@ -67,3 +67,69 @@ std::tuple handle(flwr_local::Client *client, } throw "Unkown server message"; } + +std::tuple +handle_task(flwr_local::Client *client, const flwr::proto::TaskIns &task_ins) { + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + if (!task_ins.task().has_legacy_server_message()) { + // TODO: Handle SecureAggregation + throw std::runtime_error("Task still needs legacy server message"); + } + ServerMessage server_msg = task_ins.task().legacy_server_message(); +#pragma GCC diagnostic pop + + std::tuple legacy_res = handle(client, server_msg); + std::unique_ptr client_message = + std::make_unique(std::get<0>(legacy_res)); + + flwr::proto::TaskRes task_res; + task_res.set_task_id(""); + task_res.set_group_id(""); + task_res.set_workload_id(0); + + std::unique_ptr task = + std::make_unique(); + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + task->set_allocated_legacy_client_message( + client_message.release()); // Ownership transferred to `task` +#pragma GCC diagnostic pop + + task_res.set_allocated_task(task.release()); + return std::make_tuple(task_res, std::get<1>(legacy_res), + std::get<2>(legacy_res)); +} + +flwr::proto::TaskRes +configure_task_res(const flwr::proto::TaskRes &task_res, + const flwr::proto::TaskIns &ref_task_ins, + const flwr::proto::Node &producer) { + flwr::proto::TaskRes result_task_res; + + // Setting scalar fields + result_task_res.set_task_id(""); // This will be generated by the server + result_task_res.set_group_id(ref_task_ins.group_id()); + result_task_res.set_workload_id(ref_task_ins.workload_id()); + + // Merge the task from the input task_res + *result_task_res.mutable_task() = task_res.task(); + + // Construct and set the producer and consumer for the task + std::unique_ptr new_producer = + std::make_unique(producer); + result_task_res.mutable_task()->set_allocated_producer( + new_producer.release()); + + std::unique_ptr new_consumer = + std::make_unique(ref_task_ins.task().producer()); + result_task_res.mutable_task()->set_allocated_consumer( + new_consumer.release()); + + // Set ancestry in the task + result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id()); + + return result_task_res; +} diff --git a/src/cc/flwr/src/start.cc b/src/cc/flwr/src/start.cc index e6c8362995fc..4e441a9172d0 100644 --- a/src/cc/flwr/src/start.cc +++ b/src/cc/flwr/src/start.cc @@ -12,16 +12,17 @@ void start::start_client(std::string server_address, flwr_local::Client *client, args.SetMaxSendMessageSize(grpc_max_message_length); // Establish an insecure gRPC connection to a gRPC server - std::shared_ptr channel = grpc::CreateCustomChannel( + std::shared_ptr channel = grpc::CreateCustomChannel( server_address, grpc::InsecureChannelCredentials(), args); // Create stub - std::unique_ptr stub_ = - FlowerService::NewStub(channel); + std::unique_ptr stub_ = + flwr::proto::FlowerService::NewStub(channel); // Read and write messages - ClientContext context; - std::shared_ptr> + grpc::ClientContext context; + std::shared_ptr> reader_writer(stub_->Join(&context)); ServerMessage sm; while (reader_writer->Read(&sm)) { @@ -35,7 +36,7 @@ void start::start_client(std::string server_address, flwr_local::Client *client, reader_writer->WritesDone(); // Check connection status - Status status = reader_writer->Finish(); + grpc::Status status = reader_writer->Finish(); if (sleep_duration == 0) { std::cout << "Disconnect and shut down." << std::endl; @@ -47,3 +48,58 @@ void start::start_client(std::string server_address, flwr_local::Client *client, // sleep_duration << "second(s)" << std::endl; Sleep(sleep_duration * 1000); } } + +// cppcheck-suppress unusedFunction +void start::start_rere_client(std::string server_address, + flwr_local::Client *client, + int grpc_max_message_length) { + while (true) { + int sleep_duration = 0; + + // Set channel parameters + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(grpc_max_message_length); + args.SetMaxSendMessageSize(grpc_max_message_length); + + // Establish an insecure gRPC connection to a gRPC server + std::shared_ptr channel = grpc::CreateCustomChannel( + server_address, grpc::InsecureChannelCredentials(), args); + + // Create stub + std::unique_ptr stub_ = + flwr::proto::Fleet::NewStub(channel); + + // Read and write messages + + create_node(stub_); + + while (true) { + auto task_ins = receive(stub_); + if (!task_ins) { + std::this_thread::sleep_for(std::chrono::seconds(3)); + continue; + } + auto [task_res, sleep_duration, keep_going] = + handle_task(client, task_ins.value()); + send(stub_, task_res); + if (!keep_going) { + break; + } + } + + delete_node(stub_); + if (sleep_duration == 0) { + std::cout << "Disconnect and shut down." << std::endl; + break; + } + + std::cout << "Disconnect, then re-establish connection after" + << sleep_duration << "second(s)" << std::endl; + std::this_thread::sleep_for(std::chrono::seconds(sleep_duration)); + + if (sleep_duration == 0) { + std::cout << "Disconnect and shut down." << std::endl; + break; + } + } +}