Skip to content

Commit

Permalink
[openfl.experimental] Format using black and lint with flake8
Browse files Browse the repository at this point in the history
Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed May 30, 2024
1 parent c059910 commit 5e3af9b
Show file tree
Hide file tree
Showing 26 changed files with 1,326 additions and 898 deletions.
150 changes: 94 additions & 56 deletions openfl/experimental/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,20 @@ class Aggregator:
None
"""

def __init__(self,
aggregator_uuid: str,
federation_uuid: str,
authorized_cols: List,
flow: Any,
rounds_to_train: int = 1,
checkpoint: bool = False,
private_attributes_callable: Callable = None,
private_attributes_kwargs: Dict = {},
single_col_cert_common_name: str = None,
log_metric_callback: Callable = None,
**kwargs) -> None:
def __init__(
self,
aggregator_uuid: str,
federation_uuid: str,
authorized_cols: List,
flow: Any,
rounds_to_train: int = 1,
checkpoint: bool = False,
private_attributes_callable: Callable = None,
private_attributes_kwargs: Dict = {},
single_col_cert_common_name: str = None,
log_metric_callback: Callable = None,
**kwargs,
) -> None:

self.logger = getLogger(__name__)

Expand Down Expand Up @@ -83,7 +85,8 @@ def __init__(self,
self.flow._foreach_methods = []
self.logger.info("MetaflowInterface creation.")
self.flow._metaflow_interface = MetaflowInterface(
self.flow.__class__, "single_process")
self.flow.__class__, "single_process"
)
self.flow._run_id = self.flow._metaflow_interface.create_run()
self.flow.runtime = FederatedRuntime()
self.flow.runtime.aggregator = "aggregator"
Expand Down Expand Up @@ -114,9 +117,9 @@ def __set_attributes_to_clone(self, clone: Any) -> None:
for name, attr in self.__private_attrs.items():
setattr(clone, name, attr)

def __delete_agg_attrs_from_clone(self,
clone: Any,
replace_str: str = None) -> None:
def __delete_agg_attrs_from_clone(
self, clone: Any, replace_str: str = None
) -> None:
"""
Remove aggregator private attributes from FLSpec clone before
transition from Aggregator step to collaborator steps.
Expand All @@ -127,7 +130,8 @@ def __delete_agg_attrs_from_clone(self,
for attr_name in self.__private_attrs:
if hasattr(clone, attr_name):
self.__private_attrs.update(
{attr_name: getattr(clone, attr_name)})
{attr_name: getattr(clone, attr_name)}
)
if replace_str:
setattr(clone, attr_name, replace_str)
else:
Expand All @@ -139,7 +143,8 @@ def _log_big_warning(self) -> None:
f"\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS"
f" NOT PROPER PKI AND "
f"SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN"
f" WARNED!!!")
f" WARNED!!!"
)

@staticmethod
def _get_sleep_time() -> int:
Expand Down Expand Up @@ -180,33 +185,36 @@ def run_flow(self) -> None:
if len_connected_collabs < len_sel_collabs:
# Waiting for collaborators to connect.
self.logger.info(
"Waiting for " +
f"{len_connected_collabs}/{len_sel_collabs}" +
" collaborators to connect...")
"Waiting for "
+ f"{len_connected_collabs}/{len_sel_collabs}"
+ " collaborators to connect..."
)
elif self.tasks_sent_to_collaborators != len_sel_collabs:
self.logger.info(
"Waiting for " +
f"{self.tasks_sent_to_collaborators}/{len_sel_collabs}"
+ " to make requests for tasks...")
"Waiting for "
+ f"{self.tasks_sent_to_collaborators}/{len_sel_collabs}"
+ " to make requests for tasks..."
)
else:
# Waiting for selected collaborators to send the results.
self.logger.info(
"Waiting for " +
f"{self.collaborators_counter}/{len_sel_collabs}" +
" collaborators to send results...")
"Waiting for "
+ f"{self.collaborators_counter}/{len_sel_collabs}"
+ " collaborators to send results..."
)
time.sleep(Aggregator._get_sleep_time())

self.collaborator_task_results.clear()
f_name = self.next_step
if hasattr(self, "instance_snapshot"):
self.flow.restore_instance_snapshot(
self.flow, list(self.instance_snapshot))
self.flow, list(self.instance_snapshot)
)
delattr(self, "instance_snapshot")

def call_checkpoint(self,
ctx: Any,
f: Callable,
stream_buffer: bytes = None) -> None:
def call_checkpoint(
self, ctx: Any, f: Callable, stream_buffer: bytes = None
) -> None:
"""
Perform checkpoint task.
Expand All @@ -233,8 +241,9 @@ def call_checkpoint(self,
f = pickle.loads(f)
if isinstance(stream_buffer, bytes):
# Set stream buffer as function parameter
setattr(f.__func__, "_stream_buffer",
pickle.loads(stream_buffer))
setattr(
f.__func__, "_stream_buffer", pickle.loads(stream_buffer)
)

checkpoint(ctx, f)

Expand Down Expand Up @@ -269,22 +278,34 @@ def get_tasks(self, collaborator_name: str) -> Tuple:
)
# FIXME: 0, and "" instead of None is just for protobuf compatibility.
# Cleaner solution?
return 0, "", None, Aggregator._get_sleep_time(
), self.time_to_quit
return (
0,
"",
None,
Aggregator._get_sleep_time(),
self.time_to_quit,
)

# If not time to quit then sleep for 10 seconds
time.sleep(Aggregator._get_sleep_time())

# Get collaborator step, and clone for requesting collaborator
next_step, clone = self.__collaborator_tasks_queue[
collaborator_name].get()
collaborator_name
].get()

self.tasks_sent_to_collaborators += 1
self.logger.info(
"Sending tasks to collaborator" +
f" {collaborator_name} for round {self.current_round}...")
return self.current_round, next_step, pickle.dumps(
clone), 0, self.time_to_quit
"Sending tasks to collaborator"
+ f" {collaborator_name} for round {self.current_round}..."
)
return (
self.current_round,
next_step,
pickle.dumps(clone),
0,
self.time_to_quit,
)

def do_task(self, f_name: str) -> Any:
"""
Expand All @@ -310,7 +331,8 @@ def do_task(self, f_name: str) -> Any:
f()
# Take the checkpoint of "end" step
self.__delete_agg_attrs_from_clone(
self.flow, "Private attributes: Not Available.")
self.flow, "Private attributes: Not Available."
)
self.call_checkpoint(self.flow, f)
self.__set_attributes_to_clone(self.flow)
# Check if all rounds of external loop is executed
Expand Down Expand Up @@ -347,7 +369,8 @@ def do_task(self, f_name: str) -> Any:
f(*selected_clones)

self.__delete_agg_attrs_from_clone(
self.flow, "Private attributes: Not Available.")
self.flow, "Private attributes: Not Available."
)
# Take the checkpoint of executed step
self.call_checkpoint(self.flow, f)
self.__set_attributes_to_clone(self.flow)
Expand All @@ -367,7 +390,8 @@ def do_task(self, f_name: str) -> Any:
self.clones_dict, self.instance_snapshot, self.kwargs = temp

self.selected_collaborators = getattr(
self.flow, self.kwargs["foreach"])
self.flow, self.kwargs["foreach"]
)
else:
self.kwargs = self.flow.execute_task_args[3]

Expand All @@ -379,8 +403,13 @@ def do_task(self, f_name: str) -> Any:

return f_name if f_name != "end" else None

def send_task_results(self, collab_name: str, round_number: int,
next_step: str, clone_bytes: bytes) -> None:
def send_task_results(
self,
collab_name: str,
round_number: int,
next_step: str,
clone_bytes: bytes,
) -> None:
"""
After collaborator execution, collaborator will call this function via gRPc
to send next function.
Expand All @@ -398,10 +427,13 @@ def send_task_results(self, collab_name: str, round_number: int,
if round_number is not self.current_round:
self.logger.warning(
f"Collaborator {collab_name} is reporting results"
f" for the wrong round: {round_number}. Ignoring...")
f" for the wrong round: {round_number}. Ignoring..."
)
else:
self.logger.info(f"Collaborator {collab_name} sent task results"
f" for round {round_number}.")
self.logger.info(
f"Collaborator {collab_name} sent task results"
f" for round {round_number}."
)
# Unpickle the clone (FLSpec object)
clone = pickle.loads(clone_bytes)
# Update the clone in clones_dict dictionary
Expand All @@ -416,11 +448,13 @@ def send_task_results(self, collab_name: str, round_number: int,
self.collaborator_task_results.set()
# Empty tasks_sent_to_collaborators list for next time.
if self.tasks_sent_to_collaborators == len(
self.selected_collaborators):
self.selected_collaborators
):
self.tasks_sent_to_collaborators = 0

def valid_collaborator_cn_and_id(self, cert_common_name: str,
collaborator_common_name: str) -> bool:
def valid_collaborator_cn_and_id(
self, cert_common_name: str, collaborator_common_name: str
) -> bool:
"""
Determine if the collaborator certificate and ID are valid for this federation.
Expand All @@ -437,13 +471,17 @@ def valid_collaborator_cn_and_id(self, cert_common_name: str,
# FIXME: "" instead of None is just for protobuf compatibility.
# Cleaner solution?
if self.single_col_cert_common_name == "":
return (cert_common_name == collaborator_common_name and
collaborator_common_name in self.authorized_cols)
return (
cert_common_name == collaborator_common_name
and collaborator_common_name in self.authorized_cols
)
# otherwise, common_name must be in whitelist and
# collaborator_common_name must be in authorized_cols
else:
return (cert_common_name == self.single_col_cert_common_name and
collaborator_common_name in self.authorized_cols)
return (
cert_common_name == self.single_col_cert_common_name
and collaborator_common_name in self.authorized_cols
)

def all_quit_jobs_sent(self) -> bool:
"""Assert all quit jobs are sent to collaborators."""
Expand Down
59 changes: 36 additions & 23 deletions openfl/experimental/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ class Collaborator:
\* - Plan setting.
"""

def __init__(self,
collaborator_name: str,
aggregator_uuid: str,
federation_uuid: str,
client: Any,
private_attributes_callable: Any = None,
private_attributes_kwargs: Dict = {},
**kwargs) -> None:
def __init__(
self,
collaborator_name: str,
aggregator_uuid: str,
federation_uuid: str,
client: Any,
private_attributes_callable: Any = None,
private_attributes_kwargs: Dict = {},
**kwargs,
) -> None:

self.name = collaborator_name
self.aggregator_uuid = aggregator_uuid
Expand Down Expand Up @@ -78,9 +80,9 @@ def __set_attributes_to_clone(self, clone: Any) -> None:
for name, attr in self.__private_attrs.items():
setattr(clone, name, attr)

def __delete_agg_attrs_from_clone(self,
clone: Any,
replace_str: str = None) -> None:
def __delete_agg_attrs_from_clone(
self, clone: Any, replace_str: str = None
) -> None:
"""
Remove aggregator private attributes from FLSpec clone before
transition from Aggregator step to collaborator steps
Expand All @@ -98,14 +100,16 @@ def __delete_agg_attrs_from_clone(self,
for attr_name in self.__private_attrs:
if hasattr(clone, attr_name):
self.__private_attrs.update(
{attr_name: getattr(clone, attr_name)})
{attr_name: getattr(clone, attr_name)}
)
if replace_str:
setattr(clone, attr_name, replace_str)
else:
delattr(clone, attr_name)

def call_checkpoint(self, ctx: Any, f: Callable,
stream_buffer: Any) -> None:
def call_checkpoint(
self, ctx: Any, f: Callable, stream_buffer: Any
) -> None:
"""
Call checkpoint gRPC.
Expand All @@ -117,9 +121,12 @@ def call_checkpoint(self, ctx: Any, f: Callable,
Returns:
None
"""
self.client.call_checkpoint(self.name,
pickle.dumps(ctx), pickle.dumps(f),
pickle.dumps(stream_buffer))
self.client.call_checkpoint(
self.name,
pickle.dumps(ctx),
pickle.dumps(f),
pickle.dumps(stream_buffer),
)

def run(self) -> None:
"""
Expand Down Expand Up @@ -156,10 +163,13 @@ def send_task_results(self, next_step: str, clone: Any) -> None:
Returns:
None
"""
self.logger.info(f"Round {self.round_number},"
f" collaborator {self.name} is sending results...")
self.client.send_task_results(self.name, self.round_number, next_step,
pickle.dumps(clone))
self.logger.info(
f"Round {self.round_number},"
f" collaborator {self.name} is sending results..."
)
self.client.send_task_results(
self.name, self.round_number, next_step, pickle.dumps(clone)
)

def get_tasks(self) -> Tuple:
"""
Expand All @@ -176,7 +186,9 @@ def get_tasks(self) -> Tuple:
"""
self.logger.info("Waiting for tasks...")
temp = self.client.get_tasks(self.name)
self.round_number, next_step, clone_bytes, sleep_time, time_to_quit = temp
self.round_number, next_step, clone_bytes, sleep_time, time_to_quit = (
temp
)

return next_step, pickle.loads(clone_bytes), sleep_time, time_to_quit

Expand All @@ -201,7 +213,8 @@ def do_task(self, f_name: str, ctx: Any) -> Tuple:
f()
# Checkpoint the function
self.__delete_agg_attrs_from_clone(
ctx, "Private attributes: Not Available.")
ctx, "Private attributes: Not Available."
)
self.call_checkpoint(ctx, f, f._stream_buffer)
self.__set_attributes_to_clone(ctx)

Expand Down
Loading

0 comments on commit 5e3af9b

Please sign in to comment.