From 530345ea95f6d32fe566715c8c3f321fd3e837da Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 18 Oct 2023 14:02:34 +0200 Subject: [PATCH] Refactor C++ code (#2522) --- examples/quickstart-cpp/src/simple_client.cc | 2 +- src/cc/flwr/include/client.h | 4 +- src/cc/flwr/include/message_handler.h | 16 ++-- src/cc/flwr/include/serde.h | 15 +-- src/cc/flwr/include/start.h | 12 +-- src/cc/flwr/include/typing.h | 96 ++++++++++---------- src/cc/flwr/src/message_handler.cc | 12 +-- src/cc/flwr/src/serde.cc | 86 ++++++++---------- src/cc/flwr/src/start.cc | 4 +- 9 files changed, 121 insertions(+), 126 deletions(-) diff --git a/examples/quickstart-cpp/src/simple_client.cc b/examples/quickstart-cpp/src/simple_client.cc index b47a4f140431..c246722ed15d 100644 --- a/examples/quickstart-cpp/src/simple_client.cc +++ b/examples/quickstart-cpp/src/simple_client.cc @@ -32,7 +32,7 @@ flwr_local::ParametersRes SimpleFlwrClient::get_parameters() { tensors.push_back(oss2.str()); std::string tensor_str = "cpp_double"; - return flwr_local::Parameters(tensors, tensor_str); + return flwr_local::ParametersRes(flwr_local::Parameters(tensors, tensor_str)); }; void SimpleFlwrClient::set_parameters(flwr_local::Parameters params) { diff --git a/src/cc/flwr/include/client.h b/src/cc/flwr/include/client.h index 503f95b77400..7fe1f64b7bf1 100644 --- a/src/cc/flwr/include/client.h +++ b/src/cc/flwr/include/client.h @@ -22,7 +22,7 @@ namespace flwr_local { * */ class Client { - public: +public: /** * * @brief Return the current local model parameters @@ -60,4 +60,4 @@ class Client { */ virtual EvaluateRes evaluate(EvaluateIns ins) = 0; }; -} // namespace flwr_local +} // namespace flwr_local diff --git a/src/cc/flwr/include/message_handler.h b/src/cc/flwr/include/message_handler.h index 102c1d3da79a..2c2229d2f972 100644 --- a/src/cc/flwr/include/message_handler.h +++ b/src/cc/flwr/include/message_handler.h @@ -16,24 +16,24 @@ #include "client.h" #include "serde.h" using flwr::proto::ClientMessage; -using ClientMessage_Disconnect=flwr::proto::ClientMessage_DisconnectRes; +using ClientMessage_Disconnect = flwr::proto::ClientMessage_DisconnectRes; using flwr::proto::ClientMessage_EvaluateRes; using flwr::proto::ClientMessage_FitRes; using flwr::proto::Reason; using flwr::proto::ServerMessage; using flwr::proto::ServerMessage_EvaluateIns; using flwr::proto::ServerMessage_FitIns; -using ServerMessage_Reconnect=flwr::proto::ServerMessage_ReconnectIns; +using ServerMessage_Reconnect = flwr::proto::ServerMessage_ReconnectIns; -std::tuple _reconnect( - ServerMessage_Reconnect reconnect_msg); +std::tuple +_reconnect(ServerMessage_Reconnect reconnect_msg); -ClientMessage _get_parameters(flwr_local::Client* client); +ClientMessage _get_parameters(flwr_local::Client *client); -ClientMessage _fit(flwr_local::Client* client, ServerMessage_FitIns fit_msg); +ClientMessage _fit(flwr_local::Client *client, ServerMessage_FitIns fit_msg); -ClientMessage _evaluate(flwr_local::Client* client, +ClientMessage _evaluate(flwr_local::Client *client, ServerMessage_EvaluateIns evaluate_msg); -std::tuple handle(flwr_local::Client* client, +std::tuple handle(flwr_local::Client *client, ServerMessage server_msg); diff --git a/src/cc/flwr/include/serde.h b/src/cc/flwr/include/serde.h index 32e683be4b86..314668ac650b 100644 --- a/src/cc/flwr/include/serde.h +++ b/src/cc/flwr/include/serde.h @@ -13,7 +13,9 @@ * ********************************************************************************************************/ #pragma once +// cppcheck-suppress missingInclude #include "transport.grpc.pb.h" +// cppcheck-suppress missingInclude #include "transport.pb.h" #include "typing.h" using flwr::proto::ClientMessage; @@ -23,7 +25,7 @@ using flwr::proto::Reason; using ProtoScalar = flwr::proto::Scalar; using flwr::proto::ClientMessage_EvaluateRes; using flwr::proto::ClientMessage_FitRes; -using ClientMessage_ParametersRes=flwr::proto::ClientMessage_GetParametersRes; +using ClientMessage_ParametersRes = flwr::proto::ClientMessage_GetParametersRes; using flwr::proto::ServerMessage_EvaluateIns; using flwr::proto::ServerMessage_FitIns; @@ -51,20 +53,21 @@ flwr_local::Scalar scalar_from_proto(ProtoScalar scalar_msg); * Serialize client metrics type to protobuf metrics type * "Any" is used in Python, this part might be changed if needed */ -google::protobuf::Map metrics_to_proto( - flwr_local::Metrics metrics); +google::protobuf::Map +metrics_to_proto(flwr_local::Metrics metrics); /** * Deserialize protobuf metrics type to client metrics type * "Any" is used in Python, this part might be changed if needed */ -flwr_local::Metrics metrics_from_proto( - google::protobuf::Map proto); +flwr_local::Metrics +metrics_from_proto(google::protobuf::Map proto); /** * Serialize client ParametersRes type to protobuf ParametersRes type */ -ClientMessage_ParametersRes parameters_res_to_proto(flwr_local::ParametersRes res); +ClientMessage_ParametersRes +parameters_res_to_proto(flwr_local::ParametersRes res); /** * Deserialize protobuf FitIns type to client FitIns type diff --git a/src/cc/flwr/include/start.h b/src/cc/flwr/include/start.h index c6ad9b19ec87..a99dd8520dd7 100644 --- a/src/cc/flwr/include/start.h +++ b/src/cc/flwr/include/start.h @@ -16,9 +16,9 @@ #ifndef START_H #define START_H #pragma once -#include #include "client.h" #include "message_handler.h" +#include using flwr::proto::ClientMessage; using flwr::proto::FlowerService; using flwr::proto::ServerMessage; @@ -27,7 +27,7 @@ using grpc::ClientContext; using grpc::ClientReaderWriter; using grpc::Status; -#define GRPC_MAX_MESSAGE_LENGTH 536870912 // == 512 * 1024 * 1024 +#define GRPC_MAX_MESSAGE_LENGTH 536870912 // == 512 * 1024 * 1024 /** * @brief Start a C++ Flower Client which connects to a gRPC server @@ -52,9 +52,9 @@ using grpc::Status; */ class start { - public: - static void start_client(std::string server_address, - flwr_local::Client* client, - int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH); +public: + static void + start_client(std::string server_address, flwr_local::Client *client, + int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH); }; #endif diff --git a/src/cc/flwr/include/typing.h b/src/cc/flwr/include/typing.h index 17240bb75f3c..5aee90b6c215 100644 --- a/src/cc/flwr/include/typing.h +++ b/src/cc/flwr/include/typing.h @@ -30,7 +30,7 @@ namespace flwr_local { * */ class Scalar { - public: +public: // Getters std::optional getBool() { return b; } std::optional getBytes() { return bytes; } @@ -40,12 +40,12 @@ class Scalar { // Setters void setBool(bool b) { this->b = b; } - void setBytes(std::string bytes) { this->bytes = bytes; } + void setBytes(const std::string &bytes) { this->bytes = bytes; } void setDouble(double d) { this->d = d; } void setInt(int i) { this->i = i; } - void setString(std::string string) { this->string = string; } + void setString(const std::string &string) { this->string = string; } - private: +private: std::optional b = std::nullopt; std::optional bytes = std::nullopt; std::optional d = std::nullopt; @@ -59,9 +59,10 @@ typedef std::map Metrics; * Model parameters */ class Parameters { - public: +public: Parameters() {} - Parameters(std::list tensors, std::string tensor_type) + Parameters(const std::list &tensors, + const std::string &tensor_type) : tensors(tensors), tensor_type(tensor_type) {} // Getters @@ -69,12 +70,14 @@ class Parameters { std::string getTensor_type() { return tensor_type; } // Setters - void setTensors(std::list tensors) { this->tensors = tensors; } - void setTensor_type(std::string tensor_type) { + void setTensors(const std::list &tensors) { + this->tensors = tensors; + } + void setTensor_type(const std::string &tensor_type) { this->tensor_type = tensor_type; } - private: +private: std::list tensors; std::string tensor_type; }; @@ -83,13 +86,14 @@ class Parameters { * Response when asked to return parameters */ class ParametersRes { - public: - ParametersRes(Parameters parameters) : parameters(parameters) {} +public: + explicit ParametersRes(const Parameters ¶meters) + : parameters(parameters) {} Parameters getParameters() { return parameters; } - void setParameters(Parameters p) { parameters = p; } + void setParameters(const Parameters &p) { parameters = p; } - private: +private: Parameters parameters; }; @@ -97,8 +101,9 @@ class ParametersRes { * Fit instructions for a client */ class FitIns { - public: - FitIns(Parameters parameters, std::map config) +public: + FitIns(const Parameters ¶meters, + const std::map &config) : parameters(parameters), config(config) {} // Getters @@ -106,12 +111,12 @@ class FitIns { std::map getConfig() { return config; } // Setters - void setParameters(Parameters p) { parameters = p; } - void setConfig(std::map config) { + void setParameters(const Parameters &p) { parameters = p; } + void setConfig(const std::map &config) { this->config = config; } - private: +private: Parameters parameters; std::map config; }; @@ -120,17 +125,12 @@ class FitIns { * Fit response from a client */ class FitRes { - public: +public: FitRes() {} - FitRes(Parameters parameters, - int num_examples, - int num_examples_ceil, - float fit_duration, - Metrics metrics) - : parameters(parameters), - num_examples(num_examples), - fit_duration(fit_duration), - metrics(metrics) {} + FitRes(const Parameters ¶meters, int num_examples, int num_examples_ceil, + float fit_duration, const Metrics &metrics) + : parameters(parameters), num_examples(num_examples), + fit_duration(fit_duration), metrics(metrics) {} // Getters Parameters getParameters() { return parameters; } @@ -143,16 +143,16 @@ class FitRes { std::optional getMetrics() { return metrics; } // Setters - void setParameters(Parameters p) { parameters = p; } + void setParameters(const Parameters &p) { parameters = p; } void setNum_example(int n) { num_examples = n; } /*void setNum_examples_ceil(int n) { num_examples_ceil = n; }*/ void setFit_duration(float f) { fit_duration = f; } - void setMetrics(flwr_local::Metrics m) { metrics = m; } + void setMetrics(const flwr_local::Metrics &m) { metrics = m; } - private: +private: Parameters parameters; int num_examples; // std::optional num_examples_ceil = std::nullopt; @@ -164,8 +164,9 @@ class FitRes { * Evaluate instructions for a client */ class EvaluateIns { - public: - EvaluateIns(Parameters parameters, std::map config) +public: + EvaluateIns(const Parameters ¶meters, + const std::map &config) : parameters(parameters), config(config) {} // Getters @@ -173,12 +174,12 @@ class EvaluateIns { std::map getConfig() { return config; } // Setters - void setParameters(Parameters p) { parameters = p; } - void setConfig(std::map config) { + void setParameters(const Parameters &p) { parameters = p; } + void setConfig(const std::map &config) { this->config = config; } - private: +private: Parameters parameters; std::map config; }; @@ -187,9 +188,10 @@ class EvaluateIns { * Evaluate response from a client */ class EvaluateRes { - public: +public: EvaluateRes() {} - EvaluateRes(float loss, int num_examples, float accuracy, Metrics metrics) + EvaluateRes(float loss, int num_examples, float accuracy, + const Metrics &metrics) : loss(loss), num_examples(num_examples), metrics(metrics) {} // Getters @@ -200,9 +202,9 @@ class EvaluateRes { // Setters void setLoss(float f) { loss = f; } void setNum_example(int n) { num_examples = n; } - void setMetrics(Metrics m) { metrics = m; } + void setMetrics(const Metrics &m) { metrics = m; } - private: +private: float loss; int num_examples; std::optional metrics = std::nullopt; @@ -212,29 +214,29 @@ typedef std::map Config; typedef std::map Properties; class PropertiesIns { - public: +public: PropertiesIns() {} std::map getPropertiesIns() { return static_cast>(config); } - void setPropertiesIns(Config c) { config = c; } + void setPropertiesIns(const Config &c) { config = c; } - private: +private: Config config; }; class PropertiesRes { - public: +public: PropertiesRes() {} Properties getPropertiesRes() { return properties; } - void setPropertiesRes(Properties p) { properties = p; } + void setPropertiesRes(const Properties &p) { properties = p; } - private: +private: Properties properties; }; -} // namespace flwr_local +} // namespace flwr_local diff --git a/src/cc/flwr/src/message_handler.cc b/src/cc/flwr/src/message_handler.cc index 86254317c160..b985548f6cea 100644 --- a/src/cc/flwr/src/message_handler.cc +++ b/src/cc/flwr/src/message_handler.cc @@ -1,7 +1,7 @@ #include "message_handler.h" -std::tuple _reconnect( - ServerMessage_Reconnect reconnect_msg) { +std::tuple +_reconnect(ServerMessage_Reconnect reconnect_msg) { // Determine the reason for sending Disconnect message Reason reason = Reason::ACK; int sleep_duration = 0; @@ -19,14 +19,14 @@ std::tuple _reconnect( return std::make_tuple(cm, sleep_duration); } -ClientMessage _get_parameters(flwr_local::Client* client) { +ClientMessage _get_parameters(flwr_local::Client *client) { ClientMessage cm; *(cm.mutable_get_parameters_res()) = parameters_res_to_proto(client->get_parameters()); return cm; } -ClientMessage _fit(flwr_local::Client* client, ServerMessage_FitIns fit_msg) { +ClientMessage _fit(flwr_local::Client *client, ServerMessage_FitIns fit_msg) { // Deserialize fit instruction flwr_local::FitIns fit_ins = fit_ins_from_proto(fit_msg); // Perform fit @@ -37,7 +37,7 @@ ClientMessage _fit(flwr_local::Client* client, ServerMessage_FitIns fit_msg) { return cm; } -ClientMessage _evaluate(flwr_local::Client* client, +ClientMessage _evaluate(flwr_local::Client *client, ServerMessage_EvaluateIns evaluate_msg) { // Deserialize evaluate instruction flwr_local::EvaluateIns evaluate_ins = evaluate_ins_from_proto(evaluate_msg); @@ -49,7 +49,7 @@ ClientMessage _evaluate(flwr_local::Client* client, return cm; } -std::tuple handle(flwr_local::Client* client, +std::tuple handle(flwr_local::Client *client, ServerMessage server_msg) { if (server_msg.has_reconnect_ins()) { std::tuple rec = _reconnect(server_msg.reconnect_ins()); diff --git a/src/cc/flwr/src/serde.cc b/src/cc/flwr/src/serde.cc index c6237fa187fe..eeb3bf8e44f3 100644 --- a/src/cc/flwr/src/serde.cc +++ b/src/cc/flwr/src/serde.cc @@ -7,7 +7,7 @@ MessageParameters parameters_to_proto(flwr_local::Parameters parameters) { MessageParameters mp; mp.set_tensor_type(parameters.getTensor_type()); - for (auto& i : parameters.getTensors()) { + for (auto &i : parameters.getTensors()) { mp.add_tensors(i); } return mp; @@ -60,23 +60,23 @@ ProtoScalar scalar_to_proto(flwr_local::Scalar scalar_msg) { flwr_local::Scalar scalar_from_proto(ProtoScalar scalar_msg) { flwr_local::Scalar scalar; switch (scalar_msg.scalar_case()) { - case 1: - scalar.setDouble(scalar_msg.double_()); - return scalar; - case 8: - scalar.setInt(scalar_msg.sint64()); - return scalar; - case 13: - scalar.setBool(scalar_msg.bool_()); - return scalar; - case 14: - scalar.setString(scalar_msg.string()); - return scalar; - case 15: - scalar.setBytes(scalar_msg.bytes()); - return scalar; - case 0: - break; + case 1: + scalar.setDouble(scalar_msg.double_()); + return scalar; + case 8: + scalar.setInt(scalar_msg.sint64()); + return scalar; + case 13: + scalar.setBool(scalar_msg.bool_()); + return scalar; + case 14: + scalar.setString(scalar_msg.string()); + return scalar; + case 15: + scalar.setBytes(scalar_msg.bytes()); + return scalar; + case 0: + break; } throw "Error scalar type"; } @@ -85,11 +85,11 @@ flwr_local::Scalar scalar_from_proto(ProtoScalar scalar_msg) { * Serialize client metrics type to protobuf metrics type * "Any" is used in Python, this part might be changed if needed */ -google::protobuf::Map metrics_to_proto( - flwr_local::Metrics metrics) { +google::protobuf::Map +metrics_to_proto(flwr_local::Metrics metrics) { google::protobuf::Map proto; - for (auto& [key, value] : metrics) { + for (auto &[key, value] : metrics) { proto[key] = scalar_to_proto(value); } @@ -100,11 +100,11 @@ google::protobuf::Map metrics_to_proto( * Deserialize protobuf metrics type to client metrics type * "Any" is used in Python, this part might be changed if needed */ -flwr_local::Metrics metrics_from_proto( - google::protobuf::Map proto) { +flwr_local::Metrics +metrics_from_proto(google::protobuf::Map proto) { flwr_local::Metrics metrics; - for (auto& [key, value] : proto) { + for (auto &[key, value] : proto) { metrics[key] = scalar_from_proto(value); } return metrics; @@ -113,7 +113,8 @@ flwr_local::Metrics metrics_from_proto( /** * Serialize client ParametersRes type to protobuf ParametersRes type */ -ClientMessage_ParametersRes parameters_res_to_proto(flwr_local::ParametersRes res) { +ClientMessage_ParametersRes +parameters_res_to_proto(flwr_local::ParametersRes res) { MessageParameters mp = parameters_to_proto(res.getParameters()); ClientMessage_ParametersRes cpr; *(cpr.mutable_parameters()) = mp; @@ -136,21 +137,16 @@ ClientMessage_FitRes fit_res_to_proto(flwr_local::FitRes res) { ClientMessage_FitRes cres; MessageParameters parameters_proto = parameters_to_proto(res.getParameters()); - google::protobuf::Map< ::std::string, ::flwr::proto::Scalar>* - metrics_msg; - if (res.getMetrics() == std::nullopt) { - metrics_msg = NULL; - } else { - google::protobuf::Map< ::std::string, ::flwr::proto::Scalar> proto = - metrics_to_proto(res.getMetrics().value()); - metrics_msg = &proto; + google::protobuf::Map<::std::string, ::flwr::proto::Scalar> metrics_msg; + if (res.getMetrics() != std::nullopt) { + metrics_msg = metrics_to_proto(res.getMetrics().value()); } // Forward - compatible case *(cres.mutable_parameters()) = parameters_proto; cres.set_num_examples(res.getNum_example()); - if (metrics_msg != NULL) { - *cres.mutable_metrics() = *metrics_msg; + if (!metrics_msg.empty()) { + *cres.mutable_metrics() = metrics_msg; } return cres; } @@ -169,26 +165,20 @@ flwr_local::EvaluateIns evaluate_ins_from_proto(ServerMessage_EvaluateIns msg) { */ ClientMessage_EvaluateRes evaluate_res_to_proto(flwr_local::EvaluateRes res) { ClientMessage_EvaluateRes cres; - google::protobuf::Map< ::std::string, ::flwr::proto::Scalar>* - metrics_msg; - google::protobuf::Map< ::std::string, ::flwr::proto::Scalar> proto; - if (res.getMetrics() == std::nullopt) { - metrics_msg = NULL; - } else { - proto = metrics_to_proto(res.getMetrics().value()); - metrics_msg = &proto; + google::protobuf::Map<::std::string, ::flwr::proto::Scalar> metrics_msg; + if (res.getMetrics() != std::nullopt) { + metrics_msg = metrics_to_proto(res.getMetrics().value()); } - // Forward - compatible case cres.set_loss(res.getLoss()); cres.set_num_examples(res.getNum_example()); - if (metrics_msg != NULL) { - auto& map = *cres.mutable_metrics(); + if (!metrics_msg.empty()) { + auto &map = *cres.mutable_metrics(); - for (auto& [key, value] : *metrics_msg) { + for (auto &[key, value] : metrics_msg) { map[key] = value; } } - + return cres; } diff --git a/src/cc/flwr/src/start.cc b/src/cc/flwr/src/start.cc index f8b7b3cc631e..e6c8362995fc 100644 --- a/src/cc/flwr/src/start.cc +++ b/src/cc/flwr/src/start.cc @@ -1,7 +1,7 @@ #include "start.h" -void start::start_client(std::string server_address, - flwr_local::Client* client, +// cppcheck-suppress unusedFunction +void start::start_client(std::string server_address, flwr_local::Client *client, int grpc_max_message_length) { while (true) { int sleep_duration = 0;