Skip to content

Commit

Permalink
Add Task validation to C++ SDK (#2543)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Nov 6, 2023
1 parent e70661a commit 2800482
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 48 deletions.
2 changes: 1 addition & 1 deletion examples/quickstart-cpp/include/simple_client.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/***********************************************************************************************************
*
* @file libtorch_client.h
* @file simple_client.h
*
* @brief Define an example flower client, train and test method
*
Expand Down
10 changes: 5 additions & 5 deletions src/cc/flwr/include/grpc_rere.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
/*************************************************************************************************
*
* @file start.h
* @file grpc-rere.h
*
* @brief Create a gRPC channel to connect to the server and enable message
*communication
* @brief Provide functions for establishing gRPC request-response communication
*
* @author Lekang Jiang
* @author The Flower Authors
*
* @version 1.0
*
* @date 06/09/2021
* @date 06/11/2023
*
*************************************************************************************************/

#ifndef GRPC_RERE_H
#define GRPC_RERE_H
#pragma once
#include "message_handler.h"
#include "task_handler.h"
#include <grpcpp/grpcpp.h>

void create_node(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub);
Expand Down
24 changes: 24 additions & 0 deletions src/cc/flwr/include/task_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*************************************************************************************************
*
* @file task_handler.h
*
* @brief Handle incoming or outgoing tasks
*
* @author The Flower Authors
*
* @version 1.0
*
* @date 06/11/2023
*
*************************************************************************************************/

#pragma once
#include "client.h"
#include "serde.h"

bool validate_task_ins(const flwr::proto::TaskIns &task_ins,
const bool discard_reconnect_ins);
bool validate_task_res(const flwr::proto::TaskRes &task_res);
flwr::proto::TaskRes configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &task_ins,
const flwr::proto::Node &node);
23 changes: 12 additions & 11 deletions src/cc/flwr/src/grpc_rere.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,14 @@ receive(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub) {

if (response.task_ins_list_size() > 0) {
flwr::proto::TaskIns task_ins = response.task_ins_list().at(0);
// TODO: Validate TaskIns

{
if (validate_task_ins(task_ins, true)) {
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS] = task_ins;
return task_ins;
}

return task_ins;
} else {
std::cerr << "TaskIns list is empty." << std::endl;
return std::nullopt;
}
std::cerr << "TaskIns list is empty." << std::endl;
return std::nullopt;
}

void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
Expand All @@ -136,7 +132,12 @@ void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
return;
}

// TODO: Validate TaskIns
if (!validate_task_res(task_res)) {
std::cerr << "TaskRes is invalid" << std::endl;
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS].reset();
return;
}

flwr::proto::TaskRes new_task_res =
configure_task_res(task_res, *task_ins, *node);
Expand All @@ -151,8 +152,8 @@ void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
if (!status.ok()) {
std::cerr << "PushTaskRes RPC failed with status: "
<< status.error_message() << std::endl;
return;
} else {
}
{
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS].reset();
}
Expand Down
31 changes: 0 additions & 31 deletions src/cc/flwr/src/message_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,34 +102,3 @@ handle_task(flwr_local::Client *client, const flwr::proto::TaskIns &task_ins) {
return std::make_tuple(task_res, std::get<1>(legacy_res),
std::get<2>(legacy_res));
}

flwr::proto::TaskRes
configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &ref_task_ins,
const flwr::proto::Node &producer) {
flwr::proto::TaskRes result_task_res;

// Setting scalar fields
result_task_res.set_task_id(""); // This will be generated by the server
result_task_res.set_group_id(ref_task_ins.group_id());
result_task_res.set_workload_id(ref_task_ins.workload_id());

// Merge the task from the input task_res
*result_task_res.mutable_task() = task_res.task();

// Construct and set the producer and consumer for the task
std::unique_ptr<flwr::proto::Node> new_producer =
std::make_unique<flwr::proto::Node>(producer);
result_task_res.mutable_task()->set_allocated_producer(
new_producer.release());

std::unique_ptr<flwr::proto::Node> new_consumer =
std::make_unique<flwr::proto::Node>(ref_task_ins.task().producer());
result_task_res.mutable_task()->set_allocated_consumer(
new_consumer.release());

// Set ancestry in the task
result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id());

return result_task_res;
}
52 changes: 52 additions & 0 deletions src/cc/flwr/src/task_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "task_handler.h"

bool validate_task_ins(const flwr::proto::TaskIns &task_ins,
const bool discard_reconnect_ins) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
return !(!task_ins.has_task() ||
(!task_ins.task().has_legacy_server_message() &&
!task_ins.task().has_sa()) ||
(discard_reconnect_ins &&
task_ins.task().legacy_server_message().has_reconnect_ins()));
#pragma GCC diagnostic pop
}

bool validate_task_res(const flwr::proto::TaskRes &task_res) {
// Retrieve initialized fields in TaskRes
return (task_res.task_id().empty() && task_res.group_id().empty() &&
task_res.workload_id() == 0 && !task_res.task().has_producer() &&
!task_res.task().has_producer() && !task_res.task().has_consumer() &&
task_res.task().ancestry_size() == 0);
}

flwr::proto::TaskRes
configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &ref_task_ins,
const flwr::proto::Node &producer) {
flwr::proto::TaskRes result_task_res;

// Setting scalar fields
result_task_res.set_task_id(""); // This will be generated by the server
result_task_res.set_group_id(ref_task_ins.group_id());
result_task_res.set_workload_id(ref_task_ins.workload_id());

// Merge the task from the input task_res
*result_task_res.mutable_task() = task_res.task();

// Construct and set the producer and consumer for the task
std::unique_ptr<flwr::proto::Node> new_producer =
std::make_unique<flwr::proto::Node>(producer);
result_task_res.mutable_task()->set_allocated_producer(
new_producer.release());

std::unique_ptr<flwr::proto::Node> new_consumer =
std::make_unique<flwr::proto::Node>(ref_task_ins.task().producer());
result_task_res.mutable_task()->set_allocated_consumer(
new_consumer.release());

// Set ancestry in the task
result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id());

return result_task_res;
}

0 comments on commit 2800482

Please sign in to comment.