Skip to content

Commit

Permalink
Code adaptations for new rules
Browse files Browse the repository at this point in the history
Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed Nov 28, 2024
1 parent 0f7dc74 commit e976d19
Show file tree
Hide file tree
Showing 26 changed files with 83 additions and 293 deletions.
2 changes: 1 addition & 1 deletion openfl/component/director/director.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def update_envoy_status(
if not shard_info:
raise ShardNotFoundError(f"Unknown shard {envoy_name}")

shard_info["is_online"]: True
shard_info["is_online"] = True
shard_info["is_experiment_running"] = is_experiment_running
shard_info["valid_duration"] = 2 * self.envoy_health_check_period
shard_info["last_updated"] = time.time()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
rounds_to_train: int = 1,
checkpoint: bool = False,
private_attributes_callable: Callable = None,
private_attributes_kwargs: Dict = {},
private_attributes: Dict = {},
private_attributes_kwargs: Dict = None,
private_attributes: Dict = None,
single_col_cert_common_name: str = None,
log_metric_callback: Callable = None,
**kwargs,
Expand Down Expand Up @@ -232,7 +232,7 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None) ->
f = pickle.loads(f)
if isinstance(stream_buffer, bytes):
# Set stream buffer as function parameter
setattr(f.__func__, "_stream_buffer", pickle.loads(stream_buffer))
f.__func__._stream_buffer = pickle.loads(stream_buffer)

checkpoint(ctx, f)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def __init__(
federation_uuid: str,
client: Any,
private_attributes_callable: Any = None,
private_attributes_kwargs: Dict = {},
private_attributes: Dict = {},
private_attributes_kwargs: Dict = None,
private_attributes: Dict = None,
**kwargs,
) -> None:
self.name = collaborator_name
Expand Down
98 changes: 1 addition & 97 deletions openfl/experimental/workflow/interface/cli/cli_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
"""Module with auxiliary CLI helper functions."""
import os
import re
import shutil
from itertools import islice
from os import environ, stat
from os import environ
from pathlib import Path
from sys import argv

Expand All @@ -31,20 +30,6 @@ def pretty(o):
echo(style(f"{k:<{m}} : ", fg="blue") + style(f"{v}", fg="cyan"))


def tree(path):
"""Print current directory file tree."""
echo(f"+ {path}")

for path in sorted(path.rglob("*")):
depth = len(path.relative_to(path).parts)
space = " " * depth

if path.is_file():
echo(f"{space}f {path.name}")
else:
echo(f"{space}d {path.name}")


def print_tree(
dir_path: Path,
level: int = -1,
Expand Down Expand Up @@ -91,87 +76,6 @@ def inner(dir_path: Path, prefix: str = "", level=-1):
echo(f"\n{directories} directories" + (f", {files} files" if files else ""))


def copytree(
src,
dst,
symlinks=False,
ignore=None,
ignore_dangling_symlinks=False,
dirs_exist_ok=False,
):
"""From Python 3.8 'shutil' which include 'dirs_exist_ok' option."""

with os.scandir(src) as itr:
entries = list(itr)

copy_function = shutil.copy2

def _copytree():
if ignore is not None:
ignored_names = ignore(os.fspath(src), [x.name for x in entries])
else:
ignored_names = set()

os.makedirs(dst, exist_ok=dirs_exist_ok)
errors = []
use_srcentry = copy_function is shutil.copy2 or copy_function is shutil.copy

for srcentry in entries:
if srcentry.name in ignored_names:
continue
srcname = os.path.join(src, srcentry.name)
dstname = os.path.join(dst, srcentry.name)
srcobj = srcentry if use_srcentry else srcname
try:
is_symlink = srcentry.is_symlink()
if is_symlink and os.name == "nt":
lstat = srcentry.stat(follow_symlinks=False)
if lstat.st_reparse_tag == stat.IO_REPARSE_TAG_MOUNT_POINT:
is_symlink = False
if is_symlink:
linkto = os.readlink(srcname)
if symlinks:
os.symlink(linkto, dstname)
shutil.copystat(srcobj, dstname, follow_symlinks=not symlinks)
else:
if not os.path.exists(linkto) and ignore_dangling_symlinks:
continue
if srcentry.is_dir():
copytree(
srcobj,
dstname,
symlinks,
ignore,
dirs_exist_ok=dirs_exist_ok,
)
else:
copy_function(srcobj, dstname)
elif srcentry.is_dir():
copytree(
srcobj,
dstname,
symlinks,
ignore,
dirs_exist_ok=dirs_exist_ok,
)
else:
copy_function(srcobj, dstname)
except OSError as why:
errors.append((srcname, dstname, str(why)))
except Exception as err:
errors.extend(err.args[0])
try:
shutil.copystat(src, dst)
except OSError as why:
if getattr(why, "winerror", None) is None:
errors.append((src, dst, str(why)))
if errors:
raise Exception(errors)
return dst

return _copytree()


def get_workspace_parameter(name):
"""Get a parameter from the workspace config file (.workspace)."""
# Update the .workspace file to show the current workspace plan
Expand Down
2 changes: 1 addition & 1 deletion openfl/experimental/workflow/interface/cli/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def certify_(collaborator_name, silent, request_pkg, import_):
certify(collaborator_name, silent, request_pkg, import_)


def certify(collaborator_name, silent, request_pkg=None, import_=False):
def certify(collaborator_name, silent, request_pkg=None, import_=False): # noqa C901
"""Sign/certify collaborator certificate key pair."""

common_name = f"{collaborator_name}"
Expand Down
3 changes: 1 addition & 2 deletions openfl/experimental/workflow/interface/cli/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
from tempfile import mkdtemp
from typing import Tuple

from click import Choice
from click import Choice, confirm, echo, group, option, pass_context, style
from click import Path as ClickPath
from click import confirm, echo, group, option, pass_context, style
from cryptography.hazmat.primitives import serialization

from openfl.cryptography.ca import generate_root_cert, generate_signing_csr, sign_certificate
Expand Down
4 changes: 2 additions & 2 deletions openfl/experimental/workflow/placement/placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def wrapper(*args, **kwargs):
print(f"\nCalling {f.__name__}")
with RedirectStdStreamContext() as context_stream:
# context_stream capture stdout and stderr for the function f.__name__
setattr(wrapper, "_stream_buffer", context_stream)
wrapper._stream_buffer = context_stream
f(*args, **kwargs)

return wrapper
Expand Down Expand Up @@ -92,7 +92,7 @@ def wrapper(*args, **kwargs):
print(f"\nCalling {f.__name__}")
with RedirectStdStreamContext() as context_stream:
# context_stream capture stdout and stderr for the function f.__name__
setattr(wrapper, "_stream_buffer", context_stream)
wrapper._stream_buffer = context_stream
f(*args, **kwargs)

return wrapper
3 changes: 1 addition & 2 deletions openfl/experimental/workflow/runtime/federated_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from openfl.experimental.workflow.runtime.runtime import Runtime

if TYPE_CHECKING:
from openfl.experimental.workflow.interface import Aggregator
from openfl.experimental.workflow.interface import Collaborator
from openfl.experimental.workflow.interface import Aggregator, Collaborator

from typing import List, Type

Expand Down
2 changes: 1 addition & 1 deletion openfl/experimental/workflow/runtime/local_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any:
)

interface_module = importlib.import_module("openfl.experimental.workflow.interface")
aggregator_class = getattr(interface_module, "Aggregator")
aggregator_class = interface_module.Aggregator

aggregator_actor = ray.remote(aggregator_class).options(
num_cpus=agg_cpus, num_gpus=agg_gpus
Expand Down
2 changes: 1 addition & 1 deletion openfl/experimental/workflow/utilities/runtime_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
# buffer to cycle though since need_assigned will change sizes as we
# assign participants
current_dict = need_assigned.copy()
for i, (participant_name, participant_gpu_usage) in enumerate(current_dict.items()):
for (participant_name, participant_gpu_usage) in current_dict.items():
if gpu == 0:
break
if gpu < participant_gpu_usage:
Expand Down
4 changes: 2 additions & 2 deletions openfl/experimental/workflow/utilities/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
flow_obj,
run_id,
show_html=False,
ds_root=f"{Path.home()}/.metaflow",
ds_root=None,
):
"""Initializes the InspectFlow with a flow object, run ID, an optional
flag to show the UI in a web browser, and an optional root directory
Expand All @@ -41,7 +41,7 @@ def __init__(
ds_root (str, optional): The root directory for the data store.
Defaults to "~/.metaflow".
"""
self.ds_root = ds_root
self.ds_root = ds_root or f"{Path.home()}/.metaflow"
self.show_html = show_html
self.run_id = run_id
self.flow_name = flow_obj.__class__.__name__
Expand Down
9 changes: 2 additions & 7 deletions openfl/experimental/workflow/workspace_export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,7 @@ def generate_plan_yaml(self):
"""
Generates plan.yaml
"""
flspec = getattr(
importlib.import_module("openfl.experimental.workflow.interface"), "FLSpec"
)
flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec
# Get flow classname
_, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class(flspec)
# Get expected arguments of flow class
Expand Down Expand Up @@ -343,10 +341,7 @@ def generate_data_yaml(self):

# If flow classname is not yet found
if not hasattr(self, "flow_class_name"):
flspec = getattr(
importlib.import_module("openfl.experimental.workflow.interface"),
"FLSpec",
)
flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec
_, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class(flspec)

# Import flow class
Expand Down
2 changes: 1 addition & 1 deletion openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def ignore_aliases(self, data):
yaml_path.write_text(dump(config))

@staticmethod
def parse(
def parse( # noqa: C901
plan_config_path: Path,
cols_config_path: Path = None,
data_config_path: Path = None,
Expand Down
32 changes: 20 additions & 12 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(self, **kwargs):
Attributes:
global_model (xgb.Booster): The global XGBoost model.
required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function.
required_tensorkeys_for_function (dict): A dictionary to store required tensor keys
for each function.
"""
super().__init__(**kwargs)
self.global_model = None
Expand All @@ -58,11 +59,13 @@ def rebuild_model(self, input_tensor_dict):
"""
Rebuilds the model using the provided input tensor dictionary.
This method checks if the 'local_tree' key in the input tensor dictionary is either a non-empty numpy array
If this condition is met, it updates the internal tensor dictionary with the provided input.
This method checks if the 'local_tree' key in the input tensor dictionary is either a
non-empty numpy array. If this condition is met, it updates the internal tensor dictionary
with the provided input.
Parameters:
input_tensor_dict (dict): A dictionary containing tensor data. It must include the key 'local_tree'
input_tensor_dict (dict): A dictionary containing tensor data.
It must include the key 'local_tree'
Returns:
None
Expand Down Expand Up @@ -90,11 +93,13 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
"""
data = self.data_loader.get_valid_dmatrix()

# during agg validation, self.bst will still be None. during local validation, it will have a value - no need to rebuild
# during agg validation, self.bst will still be None. during local validation,
# it will have a value - no need to rebuild
if self.bst is None:
self.rebuild_model(input_tensor_dict)

# if self.bst is still None after rebuilding, then there was no initial global model, so set metric to 0
# if self.bst is still None after rebuilding, then there was no initial global model, so
# set metric to 0
if self.bst is None:
# for first round agg validation, there is no model so set metric to 0
# TODO: this is not robust, especially if using a loss metric
Expand Down Expand Up @@ -188,16 +193,18 @@ def get_tensor_dict(self, with_opt_vars=False):
"""
Retrieves the tensor dictionary containing the model's tree structure.
This method returns a dictionary with the key 'local_tree', which contains the model's tree structure as a numpy array.
If the model has not been initialized (`self.bst` is None), it returns an empty numpy array.
If the global model is not set or is empty, it returns the entire model as a numpy array.
Otherwise, it returns only the trees added in the latest training session.
This method returns a dictionary with the key 'local_tree', which contains the model's tree
structure as a numpy array. If the model has not been initialized (`self.bst` is None), it
returns an empty numpy array. If the global model is not set or is empty, it returns the
entire model as a numpy array. Otherwise, it returns only the trees added in the latest
training session.
Parameters:
with_opt_vars (bool): N/A for XGBoost (Default=False).
Returns:
dict: A dictionary with the key 'local_tree' containing the model's tree structure as a numpy array.
dict: A dictionary with the key 'local_tree' containing the model's tree structure as a
numpy array.
"""

if self.bst is None:
Expand Down Expand Up @@ -377,7 +384,8 @@ def validate_(self, data) -> Metric:
Validate the XGBoost model.
Args:
validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'.
validation_dataloader (dict): A dictionary containing the validation data with keys
'dmatrix' and 'labels'.
Returns:
Metric: A Metric object containing the validation accuracy.
Expand Down
3 changes: 2 additions & 1 deletion openfl/interface/aggregation_functions/fed_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def append_trees(global_model, local_trees):
Parameters:
global_model (dict): A dictionary representing the global model.
local_trees (list): A list of dictionaries representing the local trees to be appended to the global model.
local_trees (list): A list of dictionaries representing the local trees to be appended to the
global model.
Returns:
dict: The updated global model with the local trees appended.
Expand Down
Loading

0 comments on commit e976d19

Please sign in to comment.