From 5bbda0fbdc12a67fcfbf37d16bf357ad603ae257 Mon Sep 17 00:00:00 2001 From: Joe Kim Date: Mon, 7 Oct 2024 11:45:57 -0700 Subject: [PATCH] Provide a standard way to get data_loader in util contexts (#1076) - Added `openfl/utilities/dataloading.py` to provide a way to get a data_loader from a plan object. The `get_dataloader` function accepts the plan object and options like `prefer_minimal` and `input_shape` to provide either a full `DataLoader` with training capability or a light `MockDataLoader` with input_shape, so that a task_runner instance can be created without full data context. - Updated `openfl/interface/plan.py` and replaced the data_loader fetch logic to use the new `get_dataloader`. - Updated `openfl/interface/model.py` and replaced the data_loader fetch logic to use the new `get_dataloader`. This lets the fx command `fx model save` save the model without any initial data being present in the workspace. Signed-off-by: Joe Kim --- openfl/interface/model.py | 25 ++++++++-- openfl/interface/plan.py | 10 ++-- openfl/utilities/dataloading.py | 87 +++++++++++++++++++++++++++++++++ openfl/utilities/mocks.py | 6 +++ 4 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 openfl/utilities/dataloading.py diff --git a/openfl/interface/model.py b/openfl/interface/model.py index e45f85d4cb..3cadb25051 100644 --- a/openfl/interface/model.py +++ b/openfl/interface/model.py @@ -5,6 +5,7 @@ """Model CLI module.""" from logging import getLogger from pathlib import Path +from typing import Union from click import Path as ClickPath from click import confirm, group, option, pass_context, style @@ -12,6 +13,8 @@ from openfl.federated import Plan from openfl.pipelines import NoCompressionPipeline from openfl.protocols import utils +from openfl.utilities.click_types import InputSpec +from openfl.utilities.dataloading import get_dataloader from openfl.utilities.workspace import set_directory logger = getLogger(__name__) @@ -71,11 +74,22 @@ def 
model(context): default="plan/data.yaml", type=ClickPath(exists=True), ) +@option( + "-f", + "--input-shape", + cls=InputSpec, + required=False, + help="The input shape to the model. May be provided as a list:\n\n" + "--input-shape [1,28,28]\n\n" + "or as a dictionary for multihead models (must be passed in quotes):\n\n" + "--input-shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ", +) def save_( context, plan_config, cols_config, data_config, + input_shape, model_protobuf_path, output_filepath, ): @@ -103,7 +117,7 @@ def save_( context.obj["fail"] = True return - task_runner = get_model(plan_config, cols_config, data_config, model_protobuf_path) + task_runner = get_model(plan_config, cols_config, data_config, model_protobuf_path, input_shape) task_runner.save_native(output_filepath) logger.info("Saved model in native format: 🠆 %s", output_filepath) @@ -114,6 +128,7 @@ def get_model( cols_config: str, data_config: str, model_protobuf_path: str, + input_shape: Union[list, dict], ): """ Initialize TaskRunner and load it with provided model.pbuf. @@ -128,6 +143,11 @@ def get_model( cols_config (str): Authorized collaborator list. data_config (str): The data set/shard configuration file. model_protobuf_path (str): The model protobuf to convert. + input_shape (list | dict ?): + input_shape denoted by list notation `[a,b,c, ...]` or in case + of multihead models, dict object with individual layer keys such + as `{"input_0": [a,b,...], "output_1": [x,y,z, ...]}` + Defaults to `None`. Returns: task_runner (instance): TaskRunner instance. 
@@ -146,8 +166,7 @@ def get_model( cols_config_path=cols_config, data_config_path=data_config, ) - collaborator_name = list(plan.cols_data_paths)[0] - data_loader = plan.get_data_loader(collaborator_name) + data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) task_runner = plan.get_task_runner(data_loader=data_loader) model_protobuf_path = Path(model_protobuf_path).resolve() diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py index 5536e04086..f7d65eb84a 100644 --- a/openfl/interface/plan.py +++ b/openfl/interface/plan.py @@ -20,7 +20,7 @@ from openfl.interface.cli_helper import get_workspace_parameter from openfl.protocols import utils from openfl.utilities.click_types import InputSpec -from openfl.utilities.mocks import MockDataLoader +from openfl.utilities.dataloading import get_dataloader from openfl.utilities.path_check import is_directory_traversal from openfl.utilities.split import split_tensor_dict_for_holdouts from openfl.utilities.utils import getfqdn_env @@ -170,11 +170,9 @@ def initialize( logger.info( "Attempting to generate initial model weights with" f" custom input shape {input_shape}" ) - data_loader = MockDataLoader(input_shape) - else: - # If feature shape is not provided, data is assumed to be present - collaborator_cname = list(plan.cols_data_paths)[0] - data_loader = plan.get_data_loader(collaborator_cname) + + data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) + task_runner = plan.get_task_runner(data_loader) tensor_pipe = plan.get_tensor_pipe() diff --git a/openfl/utilities/dataloading.py b/openfl/utilities/dataloading.py new file mode 100644 index 0000000000..d1c047c3a8 --- /dev/null +++ b/openfl/utilities/dataloading.py @@ -0,0 +1,87 @@ +# Copyright 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import os +import zipfile +from typing import Union + +from openfl.federated import Plan +from openfl.federated.data.loader import DataLoader +from 
openfl.utilities.mocks import MockDataLoader + + +def get_dataloader( + plan: Plan, + prefer_minimal: bool = False, + input_shape: Union[list, dict] = None, + collaborator_index: int = 0, +) -> DataLoader: + """Get dataloader instance from plan + + NOTE: if `prefer_minimal` is False, cwd must be the workspace directory + because we need to construct dataloader from actual collaborator data path + with actual data present. + + Args: + plan (Plan): + plan object linked with the dataloader + prefer_minimal (bool ?): + prefer to use MockDataLoader which can be used to more easily + instantiate task_runner without any initial data. + Default to `False`. + input_shape (list | dict ?): + input_shape denoted by list notation `[a,b,c, ...]` or in case + of multihead models, dict object with individual layer keys such + as `{"input_0": [a,b,...], "output_1": [x,y,z, ...]}` + Defaults to `None`. + collaborator_index (int ?): + which collaborator should be used for initializing dataloader + among collaborators specified in plan/data.yaml. + Defaults to `0`. 
+ + Returns: + data_loader (DataLoader): DataLoader instance + """ + + # if specified, try to use minimal dataloader + if prefer_minimal: + # if input_shape not given, try to ascertain input_shape from plan + if not input_shape and "input_shape" in plan.config["data_loader"]["settings"]: + input_shape = plan.config["data_loader"]["settings"]["input_shape"] + + # input_shape is resolved; we can use the minimal dataloader intended + # for util contexts which does not need a full dataloader with data + if input_shape: + data_loader: DataLoader = MockDataLoader(input_shape) + # generically inherit all attributes from data_loader.settings + for key, value in plan.config["data_loader"]["settings"].items(): + setattr(data_loader, key, value) + return data_loader + + # Fallback; try to get a dataloader by constructing it from the collaborator + # data directory path present in the current workspace + + collaborator_names = list(plan.cols_data_paths) + collatorators_count = len(collaborator_names) + + if collaborator_index >= collatorators_count: + raise Exception( + f"Unable to construct full dataloader from collab_index={collaborator_index} " + f"when the plan has {collatorators_count} as total collaborator count. " + f"Please check plan/data.yaml file for current collaborator entries." 
+ ) + + collaborator_name = collaborator_names[collaborator_index] + collaborator_data_path = plan.cols_data_paths[collaborator_name] + + # use seed_data provided by data_loader config if available + if "seed_data" in plan.config["data_loader"]["settings"] and not os.path.isdir( + collaborator_data_path + ): + os.makedirs(collaborator_data_path) + sample_data_zip_file = plan.config["data_loader"]["settings"]["seed_data"] + with zipfile.ZipFile(sample_data_zip_file, "r") as zip_ref: + zip_ref.extractall(collaborator_data_path) + + data_loader = plan.get_data_loader(collaborator_name) + + return data_loader diff --git a/openfl/utilities/mocks.py b/openfl/utilities/mocks.py index 33afee626a..9b87711845 100644 --- a/openfl/utilities/mocks.py +++ b/openfl/utilities/mocks.py @@ -13,3 +13,9 @@ def __init__(self, feature_shape): def get_feature_shape(self): return self.feature_shape + + def get_train_data_size(self): + return 0 + + def get_valid_data_size(self): + return 0