-
Notifications
You must be signed in to change notification settings - Fork 944
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Task validation to C++ SDK (#2543)
- Loading branch information
1 parent
e70661a
commit 2800482
Showing
6 changed files
with
94 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/************************************************************************************************* | ||
* | ||
* @file task_handler.h | ||
* | ||
* @brief Handle incoming or outgoing tasks | ||
* | ||
* @author The Flower Authors | ||
* | ||
* @version 1.0 | ||
* | ||
* @date 06/11/2023 | ||
* | ||
*************************************************************************************************/ | ||
|
||
#pragma once | ||
#include "client.h" | ||
#include "serde.h" | ||
|
||
bool validate_task_ins(const flwr::proto::TaskIns &task_ins, | ||
const bool discard_reconnect_ins); | ||
bool validate_task_res(const flwr::proto::TaskRes &task_res); | ||
flwr::proto::TaskRes configure_task_res(const flwr::proto::TaskRes &task_res, | ||
const flwr::proto::TaskIns &task_ins, | ||
const flwr::proto::Node &node); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#include "task_handler.h" | ||
|
||
bool validate_task_ins(const flwr::proto::TaskIns &task_ins, | ||
const bool discard_reconnect_ins) { | ||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" | ||
return !(!task_ins.has_task() || | ||
(!task_ins.task().has_legacy_server_message() && | ||
!task_ins.task().has_sa()) || | ||
(discard_reconnect_ins && | ||
task_ins.task().legacy_server_message().has_reconnect_ins())); | ||
#pragma GCC diagnostic pop | ||
} | ||
|
||
bool validate_task_res(const flwr::proto::TaskRes &task_res) { | ||
// Retrieve initialized fields in TaskRes | ||
return (task_res.task_id().empty() && task_res.group_id().empty() && | ||
task_res.workload_id() == 0 && !task_res.task().has_producer() && | ||
!task_res.task().has_producer() && !task_res.task().has_consumer() && | ||
task_res.task().ancestry_size() == 0); | ||
} | ||
|
||
flwr::proto::TaskRes | ||
configure_task_res(const flwr::proto::TaskRes &task_res, | ||
const flwr::proto::TaskIns &ref_task_ins, | ||
const flwr::proto::Node &producer) { | ||
flwr::proto::TaskRes result_task_res; | ||
|
||
// Setting scalar fields | ||
result_task_res.set_task_id(""); // This will be generated by the server | ||
result_task_res.set_group_id(ref_task_ins.group_id()); | ||
result_task_res.set_workload_id(ref_task_ins.workload_id()); | ||
|
||
// Merge the task from the input task_res | ||
*result_task_res.mutable_task() = task_res.task(); | ||
|
||
// Construct and set the producer and consumer for the task | ||
std::unique_ptr<flwr::proto::Node> new_producer = | ||
std::make_unique<flwr::proto::Node>(producer); | ||
result_task_res.mutable_task()->set_allocated_producer( | ||
new_producer.release()); | ||
|
||
std::unique_ptr<flwr::proto::Node> new_consumer = | ||
std::make_unique<flwr::proto::Node>(ref_task_ins.task().producer()); | ||
result_task_res.mutable_task()->set_allocated_consumer( | ||
new_consumer.release()); | ||
|
||
// Set ancestry in the task | ||
result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id()); | ||
|
||
return result_task_res; | ||
} |