Skip to content

Commit

Permalink
Provide a standard way to get data_loader in util contexts (#1076)
Browse files Browse the repository at this point in the history
- Added `openfl/utilities/dataloading.py` to provide a way to get a
data_loader from a plan object. The `get_dataloader` function accepts
the plan object and options like `prefer_minimal` and `input_shape`,
and provides either a full `DataLoader` with training capability or a
lightweight `MockDataLoader` carrying only an input_shape, so that a
task_runner instance can be created without a full data context.

- Updated `openfl/interface/plan.py` and replaced the data_loader
fetch logic to use the new `get_dataloader`

- Updated `openfl/interface/model.py` and replaced the data_loader
fetch logic to use the new `get_dataloader`. This enables the
`fx model save` command to save the model without any initial
data being present in the workspace.

Signed-off-by: Joe Kim <[email protected]>
  • Loading branch information
jkk-intel authored Oct 7, 2024
1 parent 28b595d commit 5bbda0f
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 9 deletions.
25 changes: 22 additions & 3 deletions openfl/interface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
"""Model CLI module."""
from logging import getLogger
from pathlib import Path
from typing import Union

from click import Path as ClickPath
from click import confirm, group, option, pass_context, style

from openfl.federated import Plan
from openfl.pipelines import NoCompressionPipeline
from openfl.protocols import utils
from openfl.utilities.click_types import InputSpec
from openfl.utilities.dataloading import get_dataloader
from openfl.utilities.workspace import set_directory

logger = getLogger(__name__)
Expand Down Expand Up @@ -71,11 +74,22 @@ def model(context):
default="plan/data.yaml",
type=ClickPath(exists=True),
)
@option(
"-f",
"--input-shape",
cls=InputSpec,
required=False,
help="The input shape to the model. May be provided as a list:\n\n"
"--input-shape [1,28,28]\n\n"
"or as a dictionary for multihead models (must be passed in quotes):\n\n"
"--input-shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ",
)
def save_(
context,
plan_config,
cols_config,
data_config,
input_shape,
model_protobuf_path,
output_filepath,
):
Expand Down Expand Up @@ -103,7 +117,7 @@ def save_(
context.obj["fail"] = True
return

task_runner = get_model(plan_config, cols_config, data_config, model_protobuf_path)
task_runner = get_model(plan_config, cols_config, data_config, model_protobuf_path, input_shape)

task_runner.save_native(output_filepath)
logger.info("Saved model in native format: 🠆 %s", output_filepath)
Expand All @@ -114,6 +128,7 @@ def get_model(
cols_config: str,
data_config: str,
model_protobuf_path: str,
input_shape: Union[list, dict],
):
"""
Initialize TaskRunner and load it with provided model.pbuf.
Expand All @@ -128,6 +143,11 @@ def get_model(
cols_config (str): Authorized collaborator list.
data_config (str): The data set/shard configuration file.
model_protobuf_path (str): The model protobuf to convert.
input_shape (list | dict ?):
input_shape denoted by list notation `[a,b,c, ...]` or in case
of multihead models, dict object with individual layer keys such
as `{"input_0": [a,b,...], "output_1": [x,y,z, ...]}`
Defaults to `None`.
Returns:
task_runner (instance): TaskRunner instance.
Expand All @@ -146,8 +166,7 @@ def get_model(
cols_config_path=cols_config,
data_config_path=data_config,
)
collaborator_name = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_name)
data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
task_runner = plan.get_task_runner(data_loader=data_loader)

model_protobuf_path = Path(model_protobuf_path).resolve()
Expand Down
10 changes: 4 additions & 6 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from openfl.interface.cli_helper import get_workspace_parameter
from openfl.protocols import utils
from openfl.utilities.click_types import InputSpec
from openfl.utilities.mocks import MockDataLoader
from openfl.utilities.dataloading import get_dataloader
from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.split import split_tensor_dict_for_holdouts
from openfl.utilities.utils import getfqdn_env
Expand Down Expand Up @@ -170,11 +170,9 @@ def initialize(
logger.info(
"Attempting to generate initial model weights with" f" custom input shape {input_shape}"
)
data_loader = MockDataLoader(input_shape)
else:
# If feature shape is not provided, data is assumed to be present
collaborator_cname = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_cname)

data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)

task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()

Expand Down
87 changes: 87 additions & 0 deletions openfl/utilities/dataloading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import zipfile
from typing import Union

from openfl.federated import Plan
from openfl.federated.data.loader import DataLoader
from openfl.utilities.mocks import MockDataLoader


def get_dataloader(
    plan: Plan,
    prefer_minimal: bool = False,
    input_shape: Union[list, dict] = None,
    collaborator_index: int = 0,
) -> DataLoader:
    """Get a dataloader instance from a plan.

    NOTE: if `prefer_minimal` is False (or no input shape can be resolved),
    cwd must be the workspace directory because the dataloader is constructed
    from the actual collaborator data path with actual data present.

    Args:
        plan (Plan): Plan object linked with the dataloader.
        prefer_minimal (bool, optional): Prefer a `MockDataLoader`, which can
            be used to instantiate a task_runner without any initial data.
            Defaults to `False`.
        input_shape (list | dict, optional): Input shape denoted by list
            notation `[a, b, c, ...]` or, for multihead models, a dict with
            individual layer keys such as
            `{"input_0": [a, b, ...], "output_1": [x, y, z, ...]}`.
            Defaults to `None`.
        collaborator_index (int, optional): Which collaborator, among those
            specified in plan/data.yaml, should be used for initializing the
            dataloader. Defaults to `0`.

    Returns:
        data_loader (DataLoader): DataLoader instance.

    Raises:
        ValueError: If `collaborator_index` is out of range for the
            collaborators listed in the plan.
    """
    # Hoist the settings lookup; tolerate a plan without a `settings` section.
    data_loader_settings = plan.config["data_loader"].get("settings", {})

    # If specified, try to use the minimal (mock) dataloader.
    if prefer_minimal:
        # If input_shape was not given, try to ascertain it from the plan.
        if not input_shape:
            input_shape = data_loader_settings.get("input_shape")

        # input_shape is resolved; we can use the minimal dataloader intended
        # for util contexts, which does not need a full dataloader with data.
        if input_shape:
            data_loader: DataLoader = MockDataLoader(input_shape)
            # Generically inherit all attributes from the dataloader settings.
            for key, value in data_loader_settings.items():
                setattr(data_loader, key, value)
            return data_loader

    # Fallback; try to get a dataloader by constructing it from the
    # collaborator data directory path present in the current workspace.
    collaborator_names = list(plan.cols_data_paths)
    collaborators_count = len(collaborator_names)

    if collaborator_index >= collaborators_count:
        # ValueError subclasses Exception, so existing handlers still work.
        raise ValueError(
            f"Unable to construct full dataloader from collab_index={collaborator_index} "
            f"when the plan has {collaborators_count} as total collaborator count. "
            f"Please check plan/data.yaml file for current collaborator entries."
        )

    collaborator_name = collaborator_names[collaborator_index]
    collaborator_data_path = plan.cols_data_paths[collaborator_name]

    # Use seed_data provided by the data_loader config, if available, to
    # populate an otherwise-missing collaborator data directory.
    if "seed_data" in data_loader_settings and not os.path.isdir(collaborator_data_path):
        os.makedirs(collaborator_data_path)
        sample_data_zip_file = data_loader_settings["seed_data"]
        with zipfile.ZipFile(sample_data_zip_file, "r") as zip_ref:
            zip_ref.extractall(collaborator_data_path)

    data_loader = plan.get_data_loader(collaborator_name)

    return data_loader
6 changes: 6 additions & 0 deletions openfl/utilities/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ def __init__(self, feature_shape):

def get_feature_shape(self):
    """Return the feature/input shape this mock loader was constructed with."""
    return self.feature_shape

def get_train_data_size(self):
    """Return the training data size; always 0, as the mock loader holds no data."""
    return 0

def get_valid_data_size(self):
    """Return the validation data size; always 0, as the mock loader holds no data."""
    return 0

0 comments on commit 5bbda0f

Please sign in to comment.