From b77f518160357aee00121b62ca5b19488f4f3c42 Mon Sep 17 00:00:00 2001 From: yes Date: Mon, 18 Nov 2024 05:09:44 -0800 Subject: [PATCH] removed python native api dependency in interactive api experiment t.py Signed-off-by: yes --- .../interface/interactive_api/experiment.py | 119 +++++++++++++++++- 1 file changed, 117 insertions(+), 2 deletions(-) diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index 6e09c5bbeb..362cf88945 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -5,6 +5,7 @@ """Python low-level API module.""" import os import time +import flatten_json from collections import defaultdict from copy import deepcopy from logging import getLogger @@ -21,7 +22,6 @@ from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage from openfl.interface.cli import setup_logging from openfl.interface.cli_helper import WORKSPACE -from openfl.native import update_plan from openfl.utilities.split import split_tensor_dict_for_holdouts from openfl.utilities.utils import rmtree from openfl.utilities.workspace import dump_requirements_file @@ -596,7 +596,7 @@ def _prepare_plan( } if override_config: - self.plan = update_plan(override_config, plan=self.plan, resolve=False) + self.plan = self.update_plan(override_config, plan=self.plan, resolve=False) def _serialize_interface_objects(self, model_provider, task_keeper, data_loader, task_assigner): """ @@ -631,6 +631,121 @@ def _serialize_interface_objects(self, model_provider, task_keeper, data_loader, for filename, object_ in obj_dict.items(): serializer.serialize(object_, self.plan.config["api_layer"]["settings"][filename]) + def setup_plan(self, log_level="CRITICAL"): + """ + Dump the plan with all defaults + overrides set. + + Args: + log_level (str, optional): The log level Whether to save the plan to + disk. + Defaults to 'CRITICAL'. + + Returns: + plan: Plan object. + """ + plan_config = "plan/plan.yaml" + cols_config = "plan/cols.yaml" + data_config = "plan/data.yaml" + + current_level = self.logger.root.level + getLogger().setLevel(log_level) + plan = Plan.parse( + plan_config_path=Path(plan_config), + cols_config_path=Path(cols_config), + data_config_path=Path(data_config), + resolve=False, + ) + getLogger().setLevel(current_level) + + return plan + + def flatten(self, config, return_complete=False): + """ + Flatten nested config. + + Args: + config (dict): The configuration dictionary to flatten. + return_complete (bool, optional): Whether to return the complete + flattened config. Defaults to False. + + Returns: + flattened_config (dict): The flattened configuration dictionary. + """ + flattened_config = flatten_json.flatten(config, ".") + if not return_complete: + keys_to_remove = [k for k, v in flattened_config.items() if ("defaults" in k or v is None)] + else: + keys_to_remove = [k for k, v in flattened_config.items() if v is None] + for k in keys_to_remove: + del flattened_config[k] + + return flattened_config + + def unflatten(self, config, separator="."): + """Unfolds `config` settings that have `separator` in their names. + + Args: + config (dict): The flattened configuration dictionary to unfold. + separator (str, optional): The separator used in the flattened config. + Defaults to '.'. + + Returns: + config (dict): The unfolded configuration dictionary. + """ + config = flatten_json.unflatten_list(config, separator) + return config + + + def update_plan(self, override_config, plan=None, resolve=True): + """Updates the plan with the provided override and saves it to disk. + + Args: + override_config (dict): A dictionary of values to override in the plan. + plan (Plan, optional): The plan to update. If None, a new plan is set + up. Defaults to None. + resolve (bool, optional): Whether to resolve the plan. Defaults to + True. + + Returns: + plan (object): The updated plan. + """ + if plan is None: + plan = self.setup_plan() + flat_plan_config = self.flatten(plan.config, return_complete=True) + + org_list_keys_with_count = {} + for k in flat_plan_config: + k_split = k.rsplit(".", 1) + if k_split[1].isnumeric(): + if k_split[0] in org_list_keys_with_count: + org_list_keys_with_count[k_split[0]] += 1 + else: + org_list_keys_with_count[k_split[0]] = 1 + + for key, val in override_config.items(): + if key in org_list_keys_with_count: + # remove old list corresponding to this key entirely + for idx in range(org_list_keys_with_count[key]): + del flat_plan_config[f"{key}.{idx}"] + self.logger.info("Updating %s to %s... ", key, val) + elif key in flat_plan_config: + self.logger.info("Updating %s to %s... ", key, val) + else: + # TODO: We probably need to validate the new key somehow + self.logger.info( + "Did not find %s in config. Make sure it should exist. Creating...", + key, + ) + if type(val) is list: + for idx, v in enumerate(val): + flat_plan_config[f"{key}.{idx}"] = v + else: + flat_plan_config[key] = val + + plan.config = self.unflatten(flat_plan_config, ".") + if resolve: + plan.resolve() + return plan class TaskKeeper: """Task keeper class.