Skip to content

Commit

Permalink
removed python native api dependency in interactive api experiment t.py
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Nov 18, 2024
1 parent 5e3c907 commit b77f518
Showing 1 changed file with 117 additions and 2 deletions.
119 changes: 117 additions & 2 deletions openfl/interface/interactive_api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Python low-level API module."""
import os
import time
import flatten_json
from collections import defaultdict
from copy import deepcopy
from logging import getLogger
Expand All @@ -21,7 +22,6 @@
from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage
from openfl.interface.cli import setup_logging
from openfl.interface.cli_helper import WORKSPACE
from openfl.native import update_plan
from openfl.utilities.split import split_tensor_dict_for_holdouts
from openfl.utilities.utils import rmtree
from openfl.utilities.workspace import dump_requirements_file
Expand Down Expand Up @@ -596,7 +596,7 @@ def _prepare_plan(
}

if override_config:
self.plan = update_plan(override_config, plan=self.plan, resolve=False)
self.plan = self.update_plan(override_config, plan=self.plan, resolve=False)

def _serialize_interface_objects(self, model_provider, task_keeper, data_loader, task_assigner):
"""
Expand Down Expand Up @@ -631,6 +631,121 @@ def _serialize_interface_objects(self, model_provider, task_keeper, data_loader,
for filename, object_ in obj_dict.items():
serializer.serialize(object_, self.plan.config["api_layer"]["settings"][filename])

def setup_plan(self, log_level="CRITICAL"):
"""
Dump the plan with all defaults + overrides set.
Args:
log_level (str, optional): The log level Whether to save the plan to
disk.
Defaults to 'CRITICAL'.
Returns:
plan: Plan object.
"""
plan_config = "plan/plan.yaml"
cols_config = "plan/cols.yaml"
data_config = "plan/data.yaml"

current_level = self.logger.root.level
getLogger().setLevel(log_level)
plan = Plan.parse(
plan_config_path=Path(plan_config),
cols_config_path=Path(cols_config),
data_config_path=Path(data_config),
resolve=False,
)
getLogger().setLevel(current_level)

return plan

def flatten(self, config, return_complete=False):
"""
Flatten nested config.
Args:
config (dict): The configuration dictionary to flatten.
return_complete (bool, optional): Whether to return the complete
flattened config. Defaults to False.
Returns:
flattened_config (dict): The flattened configuration dictionary.
"""
flattened_config = flatten_json.flatten(config, ".")
if not return_complete:
keys_to_remove = [k for k, v in flattened_config.items() if ("defaults" in k or v is None)]
else:
keys_to_remove = [k for k, v in flattened_config.items() if v is None]
for k in keys_to_remove:
del flattened_config[k]

return flattened_config

def unflatten(self, config, separator="."):
"""Unfolds `config` settings that have `separator` in their names.
Args:
config (dict): The flattened configuration dictionary to unfold.
separator (str, optional): The separator used in the flattened config.
Defaults to '.'.
Returns:
config (dict): The unfolded configuration dictionary.
"""
config = flatten_json.unflatten_list(config, separator)
return config


def update_plan(self, override_config, plan=None, resolve=True):
"""Updates the plan with the provided override and saves it to disk.
Args:
override_config (dict): A dictionary of values to override in the plan.
plan (Plan, optional): The plan to update. If None, a new plan is set
up. Defaults to None.
resolve (bool, optional): Whether to resolve the plan. Defaults to
True.
Returns:
plan (object): The updated plan.
"""
if plan is None:
plan = self.setup_plan()
flat_plan_config = self.flatten(plan.config, return_complete=True)

org_list_keys_with_count = {}
for k in flat_plan_config:
k_split = k.rsplit(".", 1)
if k_split[1].isnumeric():
if k_split[0] in org_list_keys_with_count:
org_list_keys_with_count[k_split[0]] += 1
else:
org_list_keys_with_count[k_split[0]] = 1

for key, val in override_config.items():
if key in org_list_keys_with_count:
# remove old list corresponding to this key entirely
for idx in range(org_list_keys_with_count[key]):
del flat_plan_config[f"{key}.{idx}"]
self.logger.info("Updating %s to %s... ", key, val)
elif key in flat_plan_config:
self.logger.info("Updating %s to %s... ", key, val)
else:
# TODO: We probably need to validate the new key somehow
self.logger.info(
"Did not find %s in config. Make sure it should exist. Creating...",
key,
)
if type(val) is list:
for idx, v in enumerate(val):
flat_plan_config[f"{key}.{idx}"] = v
else:
flat_plan_config[key] = val

plan.config = self.unflatten(flat_plan_config, ".")
if resolve:
plan.resolve()
return plan

class TaskKeeper:
"""Task keeper class.
Expand Down

0 comments on commit b77f518

Please sign in to comment.