Skip to content

Commit

Permalink
Refactor C++ code (#2522)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Oct 18, 2023
1 parent 02908d8 commit 530345e
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 126 deletions.
2 changes: 1 addition & 1 deletion examples/quickstart-cpp/src/simple_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ flwr_local::ParametersRes SimpleFlwrClient::get_parameters() {
tensors.push_back(oss2.str());

std::string tensor_str = "cpp_double";
return flwr_local::Parameters(tensors, tensor_str);
return flwr_local::ParametersRes(flwr_local::Parameters(tensors, tensor_str));
};

void SimpleFlwrClient::set_parameters(flwr_local::Parameters params) {
Expand Down
4 changes: 2 additions & 2 deletions src/cc/flwr/include/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace flwr_local {
*
*/
class Client {
public:
public:
/**
*
* @brief Return the current local model parameters
Expand Down Expand Up @@ -60,4 +60,4 @@ class Client {
*/
virtual EvaluateRes evaluate(EvaluateIns ins) = 0;
};
} // namespace flwr_local
} // namespace flwr_local
16 changes: 8 additions & 8 deletions src/cc/flwr/include/message_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@
#include "client.h"
#include "serde.h"
using flwr::proto::ClientMessage;
using ClientMessage_Disconnect=flwr::proto::ClientMessage_DisconnectRes;
using ClientMessage_Disconnect = flwr::proto::ClientMessage_DisconnectRes;
using flwr::proto::ClientMessage_EvaluateRes;
using flwr::proto::ClientMessage_FitRes;
using flwr::proto::Reason;
using flwr::proto::ServerMessage;
using flwr::proto::ServerMessage_EvaluateIns;
using flwr::proto::ServerMessage_FitIns;
using ServerMessage_Reconnect=flwr::proto::ServerMessage_ReconnectIns;
using ServerMessage_Reconnect = flwr::proto::ServerMessage_ReconnectIns;

std::tuple<ClientMessage, int> _reconnect(
ServerMessage_Reconnect reconnect_msg);
std::tuple<ClientMessage, int>
_reconnect(ServerMessage_Reconnect reconnect_msg);

ClientMessage _get_parameters(flwr_local::Client* client);
ClientMessage _get_parameters(flwr_local::Client *client);

ClientMessage _fit(flwr_local::Client* client, ServerMessage_FitIns fit_msg);
ClientMessage _fit(flwr_local::Client *client, ServerMessage_FitIns fit_msg);

ClientMessage _evaluate(flwr_local::Client* client,
ClientMessage _evaluate(flwr_local::Client *client,
ServerMessage_EvaluateIns evaluate_msg);

std::tuple<ClientMessage, int, bool> handle(flwr_local::Client* client,
std::tuple<ClientMessage, int, bool> handle(flwr_local::Client *client,
ServerMessage server_msg);
15 changes: 9 additions & 6 deletions src/cc/flwr/include/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
* ********************************************************************************************************/

#pragma once
// cppcheck-suppress missingInclude
#include "transport.grpc.pb.h"
// cppcheck-suppress missingInclude
#include "transport.pb.h"
#include "typing.h"
using flwr::proto::ClientMessage;
Expand All @@ -23,7 +25,7 @@ using flwr::proto::Reason;
using ProtoScalar = flwr::proto::Scalar;
using flwr::proto::ClientMessage_EvaluateRes;
using flwr::proto::ClientMessage_FitRes;
using ClientMessage_ParametersRes=flwr::proto::ClientMessage_GetParametersRes;
using ClientMessage_ParametersRes = flwr::proto::ClientMessage_GetParametersRes;
using flwr::proto::ServerMessage_EvaluateIns;
using flwr::proto::ServerMessage_FitIns;

Expand Down Expand Up @@ -51,20 +53,21 @@ flwr_local::Scalar scalar_from_proto(ProtoScalar scalar_msg);
* Serialize client metrics type to protobuf metrics type
* "Any" is used in Python, this part might be changed if needed
*/
google::protobuf::Map<std::string, ProtoScalar> metrics_to_proto(
flwr_local::Metrics metrics);
google::protobuf::Map<std::string, ProtoScalar>
metrics_to_proto(flwr_local::Metrics metrics);

/**
* Deserialize protobuf metrics type to client metrics type
* "Any" is used in Python, this part might be changed if needed
*/
flwr_local::Metrics metrics_from_proto(
google::protobuf::Map<std::string, ProtoScalar> proto);
flwr_local::Metrics
metrics_from_proto(google::protobuf::Map<std::string, ProtoScalar> proto);

/**
* Serialize client ParametersRes type to protobuf ParametersRes type
*/
ClientMessage_ParametersRes parameters_res_to_proto(flwr_local::ParametersRes res);
ClientMessage_ParametersRes
parameters_res_to_proto(flwr_local::ParametersRes res);

/**
* Deserialize protobuf FitIns type to client FitIns type
Expand Down
12 changes: 6 additions & 6 deletions src/cc/flwr/include/start.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
#ifndef START_H
#define START_H
#pragma once
#include <grpcpp/grpcpp.h>
#include "client.h"
#include "message_handler.h"
#include <grpcpp/grpcpp.h>
using flwr::proto::ClientMessage;
using flwr::proto::FlowerService;
using flwr::proto::ServerMessage;
Expand All @@ -27,7 +27,7 @@ using grpc::ClientContext;
using grpc::ClientReaderWriter;
using grpc::Status;

#define GRPC_MAX_MESSAGE_LENGTH 536870912 // == 512 * 1024 * 1024
#define GRPC_MAX_MESSAGE_LENGTH 536870912 // == 512 * 1024 * 1024

/**
* @brief Start a C++ Flower Client which connects to a gRPC server
Expand All @@ -52,9 +52,9 @@ using grpc::Status;
*/

class start {
public:
static void start_client(std::string server_address,
flwr_local::Client* client,
int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH);
public:
static void
start_client(std::string server_address, flwr_local::Client *client,
int grpc_max_message_length = GRPC_MAX_MESSAGE_LENGTH);
};
#endif
96 changes: 49 additions & 47 deletions src/cc/flwr/include/typing.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace flwr_local {
*
*/
class Scalar {
public:
public:
// Getters
std::optional<bool> getBool() { return b; }
std::optional<std::string> getBytes() { return bytes; }
Expand All @@ -40,12 +40,12 @@ class Scalar {

// Setters
void setBool(bool b) { this->b = b; }
void setBytes(std::string bytes) { this->bytes = bytes; }
void setBytes(const std::string &bytes) { this->bytes = bytes; }
void setDouble(double d) { this->d = d; }
void setInt(int i) { this->i = i; }
void setString(std::string string) { this->string = string; }
void setString(const std::string &string) { this->string = string; }

private:
private:
std::optional<bool> b = std::nullopt;
std::optional<std::string> bytes = std::nullopt;
std::optional<double> d = std::nullopt;
Expand All @@ -59,22 +59,25 @@ typedef std::map<std::string, flwr_local::Scalar> Metrics;
* Model parameters
*/
class Parameters {
public:
public:
Parameters() {}
Parameters(std::list<std::string> tensors, std::string tensor_type)
Parameters(const std::list<std::string> &tensors,
const std::string &tensor_type)
: tensors(tensors), tensor_type(tensor_type) {}

// Getters
std::list<std::string> getTensors() { return tensors; }
std::string getTensor_type() { return tensor_type; }

// Setters
void setTensors(std::list<std::string> tensors) { this->tensors = tensors; }
void setTensor_type(std::string tensor_type) {
void setTensors(const std::list<std::string> &tensors) {
this->tensors = tensors;
}
void setTensor_type(const std::string &tensor_type) {
this->tensor_type = tensor_type;
}

private:
private:
std::list<std::string> tensors;
std::string tensor_type;
};
Expand All @@ -83,35 +86,37 @@ class Parameters {
* Response when asked to return parameters
*/
class ParametersRes {
public:
ParametersRes(Parameters parameters) : parameters(parameters) {}
public:
explicit ParametersRes(const Parameters &parameters)
: parameters(parameters) {}

Parameters getParameters() { return parameters; }
void setParameters(Parameters p) { parameters = p; }
void setParameters(const Parameters &p) { parameters = p; }

private:
private:
Parameters parameters;
};

/**
* Fit instructions for a client
*/
class FitIns {
public:
FitIns(Parameters parameters, std::map<std::string, flwr_local::Scalar> config)
public:
FitIns(const Parameters &parameters,
const std::map<std::string, flwr_local::Scalar> &config)
: parameters(parameters), config(config) {}

// Getters
Parameters getParameters() { return parameters; }
std::map<std::string, Scalar> getConfig() { return config; }

// Setters
void setParameters(Parameters p) { parameters = p; }
void setConfig(std::map<std::string, Scalar> config) {
void setParameters(const Parameters &p) { parameters = p; }
void setConfig(const std::map<std::string, Scalar> &config) {
this->config = config;
}

private:
private:
Parameters parameters;
std::map<std::string, Scalar> config;
};
Expand All @@ -120,17 +125,12 @@ class FitIns {
* Fit response from a client
*/
class FitRes {
public:
public:
FitRes() {}
FitRes(Parameters parameters,
int num_examples,
int num_examples_ceil,
float fit_duration,
Metrics metrics)
: parameters(parameters),
num_examples(num_examples),
fit_duration(fit_duration),
metrics(metrics) {}
FitRes(const Parameters &parameters, int num_examples, int num_examples_ceil,
float fit_duration, const Metrics &metrics)
: parameters(parameters), num_examples(num_examples),
fit_duration(fit_duration), metrics(metrics) {}

// Getters
Parameters getParameters() { return parameters; }
Expand All @@ -143,16 +143,16 @@ class FitRes {
std::optional<Metrics> getMetrics() { return metrics; }

// Setters
void setParameters(Parameters p) { parameters = p; }
void setParameters(const Parameters &p) { parameters = p; }
void setNum_example(int n) { num_examples = n; }
/*void setNum_examples_ceil(int n)
{
num_examples_ceil = n;
}*/
void setFit_duration(float f) { fit_duration = f; }
void setMetrics(flwr_local::Metrics m) { metrics = m; }
void setMetrics(const flwr_local::Metrics &m) { metrics = m; }

private:
private:
Parameters parameters;
int num_examples;
// std::optional<int> num_examples_ceil = std::nullopt;
Expand All @@ -164,21 +164,22 @@ class FitRes {
* Evaluate instructions for a client
*/
class EvaluateIns {
public:
EvaluateIns(Parameters parameters, std::map<std::string, Scalar> config)
public:
EvaluateIns(const Parameters &parameters,
const std::map<std::string, Scalar> &config)
: parameters(parameters), config(config) {}

// Getters
Parameters getParameters() { return parameters; }
std::map<std::string, Scalar> getConfig() { return config; }

// Setters
void setParameters(Parameters p) { parameters = p; }
void setConfig(std::map<std::string, Scalar> config) {
void setParameters(const Parameters &p) { parameters = p; }
void setConfig(const std::map<std::string, Scalar> &config) {
this->config = config;
}

private:
private:
Parameters parameters;
std::map<std::string, Scalar> config;
};
Expand All @@ -187,9 +188,10 @@ class EvaluateIns {
* Evaluate response from a client
*/
class EvaluateRes {
public:
public:
EvaluateRes() {}
EvaluateRes(float loss, int num_examples, float accuracy, Metrics metrics)
EvaluateRes(float loss, int num_examples, float accuracy,
const Metrics &metrics)
: loss(loss), num_examples(num_examples), metrics(metrics) {}

// Getters
Expand All @@ -200,9 +202,9 @@ class EvaluateRes {
// Setters
void setLoss(float f) { loss = f; }
void setNum_example(int n) { num_examples = n; }
void setMetrics(Metrics m) { metrics = m; }
void setMetrics(const Metrics &m) { metrics = m; }

private:
private:
float loss;
int num_examples;
std::optional<Metrics> metrics = std::nullopt;
Expand All @@ -212,29 +214,29 @@ typedef std::map<std::string, flwr_local::Scalar> Config;
typedef std::map<std::string, flwr_local::Scalar> Properties;

class PropertiesIns {
public:
public:
PropertiesIns() {}

std::map<std::string, flwr_local::Scalar> getPropertiesIns() {
return static_cast<std::map<std::string, flwr_local::Scalar>>(config);
}

void setPropertiesIns(Config c) { config = c; }
void setPropertiesIns(const Config &c) { config = c; }

private:
private:
Config config;
};

class PropertiesRes {
public:
public:
PropertiesRes() {}

Properties getPropertiesRes() { return properties; }

void setPropertiesRes(Properties p) { properties = p; }
void setPropertiesRes(const Properties &p) { properties = p; }

private:
private:
Properties properties;
};

} // namespace flwr_local
} // namespace flwr_local
Loading

0 comments on commit 530345e

Please sign in to comment.