diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 8c5e672d25..7d61402369 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -335,7 +335,19 @@ def get_aggregator(self, tensor_dict=None): defaults[SETTINGS]['log_metric_callback'] = log_metric_callback if self.aggregator_ is None: - self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict) + # Try loading the aggregator from the serialized object file + api_layer = self.config.get('api_layer', {}) + api_settings = api_layer.get('settings', {}) + obj_file = api_settings.get('aggregator_interface_file') + if obj_file is not None: + self.aggregator_ = self.restore_object(obj_file) + + if self.aggregator_ is None: # Failed to load the object file + self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict) + else: + # Pass the settings to the deserialized object + self.aggregator_.initialize(**defaults[SETTINGS], + initial_tensor_dict=tensor_dict) return self.aggregator_ @@ -491,7 +503,18 @@ def get_collaborator(self, collaborator_name, root_certificate=None, private_key ) if self.collaborator_ is None: - self.collaborator_ = Plan.build(**defaults) + # Try loading the collaborator from the serialized object file + api_layer = self.config.get('api_layer', {}) + api_settings = api_layer.get('settings', {}) + obj_file = api_settings.get('collaborator_interface_file') + if obj_file is not None: + self.collaborator_ = self.restore_object(obj_file) + + if self.collaborator_ is None: # Failed to load from the object file + self.collaborator_ = Plan.build(**defaults) + else: + # Pass the settings to the deserialized object + self.collaborator_.initialize(**defaults[SETTINGS]) return self.collaborator_ @@ -588,8 +611,12 @@ def get_serializer_plugin(self, **kwargs): def restore_object(self, filename): """Deserialize an object.""" + import os + serializer_plugin = self.get_serializer_plugin() if serializer_plugin is None: return None + if not os.path.exists(filename): + return None obj = serializer_plugin.restore_object(filename) return obj diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index 218a6eca8d..daca38b21d 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -175,11 +175,14 @@ def remove_experiment_data(self): self.logger.info(log_message) def prepare_workspace_distribution(self, model_provider, task_keeper, data_loader, + aggregator, collaborator, task_assigner, pip_install_options: Tuple[str] = ()): """Prepare an archive from a user workspace.""" # Save serialized python objects to disc - self._serialize_interface_objects(model_provider, task_keeper, data_loader, task_assigner) + self._serialize_interface_objects(model_provider, task_keeper, data_loader, + aggregator, collaborator, + task_assigner) # Save the prepared plan Plan.dump(Path(f'./plan/{self.plan.name}'), self.plan.config, freeze=False) @@ -193,6 +196,8 @@ def prepare_workspace_distribution(self, model_provider, task_keeper, data_loade def start(self, *, model_provider, task_keeper, data_loader, rounds_to_train: int, + aggregator=None, + collaborator=None, task_assigner=None, override_config: dict = None, delta_updates: bool = False, @@ -210,6 +215,8 @@ def start(self, *, model_provider, task_keeper, data_loader, task_keeper - Task Interface instance. data_loader - Data Interface instance. rounds_to_train - required number of training rounds for the experiment. + aggregator - Aggregator Interface instance. + collaborator - Collaborator Interface instance. delta_updates - [bool] Tells if collaborators should send delta updates for the locally tuned models. If set to False, whole checkpoints will be sent. opt_treatment - Optimizer state treatment policy. @@ -234,10 +241,13 @@ def start(self, *, model_provider, task_keeper, data_loader, override_config=override_config, model_interface_file='model_obj.pkl', tasks_interface_file='tasks_obj.pkl', - dataloader_interface_file='loader_obj.pkl') + dataloader_interface_file='loader_obj.pkl', + aggregator_interface_file='aggregator_obj.pkl', + collaborator_interface_file='collaborator_obj.pkl') self.prepare_workspace_distribution( model_provider, task_keeper, data_loader, + aggregator, collaborator, task_assigner, pip_install_options ) @@ -360,6 +370,8 @@ def _prepare_plan(self, model_provider, data_loader, override_config=None, model_interface_file='model_obj.pkl', tasks_interface_file='tasks_obj.pkl', dataloader_interface_file='loader_obj.pkl', + aggregator_interface_file='aggregator_obj.pkl', + collaborator_interface_file='collaborator_obj.pkl', aggregation_function_interface_file='aggregation_function_obj.pkl', task_assigner_file='task_assigner_obj.pkl'): """Fill plan.yaml file using user provided setting.""" @@ -410,6 +422,8 @@ def _prepare_plan(self, model_provider, data_loader, 'model_interface_file': model_interface_file, 'tasks_interface_file': tasks_interface_file, 'dataloader_interface_file': dataloader_interface_file, + 'aggregator_interface_file': aggregator_interface_file, + 'collaborator_interface_file': collaborator_interface_file, 'aggregation_function_interface_file': aggregation_function_interface_file, 'task_assigner_file': task_assigner_file } @@ -423,6 +437,8 @@ def _serialize_interface_objects( model_provider, task_keeper, data_loader, + aggregator, + collaborator, task_assigner ): """Save python objects to be restored on collaborators.""" @@ -436,12 +452,15 @@ def _serialize_interface_objects( 'model_interface_file': model_provider, 'tasks_interface_file': task_keeper, 'dataloader_interface_file': data_loader, + 'aggregator_interface_file': aggregator, + 'collaborator_interface_file': collaborator, 'aggregation_function_interface_file': task_keeper.aggregation_functions, 'task_assigner_file': task_assigner } for filename, object_ in obj_dict.items(): - serializer.serialize(object_, self.plan.config['api_layer']['settings'][filename]) + if object_ is not None: + serializer.serialize(object_, self.plan.config['api_layer']['settings'][filename]) class TaskKeeper: @@ -668,3 +687,207 @@ def get_train_data_size(self): def get_valid_data_size(self): """Information for aggregation.""" raise NotImplementedError + + +class AggregatorInterface: + """ + The class to define an aggregator. + + The experiment manager can define a customized aggregator object for some ad-hoc + federated learning task. + + """ + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def initialize(self, + + aggregator_uuid, + federation_uuid, + authorized_cols, + + rounds_to_train=256, + single_col_cert_common_name=None, + **kwargs): + """ + This method is called by the Plan component. + + Args: + aggregator_uuid (str): Aggregation ID. + federation_uuid (str): Federation ID. + authorized_cols (list of str): The list of IDs of enrolled collaborators. + init_state_path* (str): The location of the initial weight file. + last_state_path* (str): The file location to store the latest weight. + best_state_path* (str): The file location to store the weight of the best model. + """ + + self.round_number = 0 + self.single_col_cert_common_name = single_col_cert_common_name + if self.single_col_cert_common_name is None: + self.single_col_cert_common_name = '' + + self.rounds_to_train = rounds_to_train + + self.authorized_cols = authorized_cols + self.uuid = aggregator_uuid + self.federation_uuid = federation_uuid + + self.logger = getLogger(__name__) + + self.quit_job_sent_to = [] + + self.kwargs.update(kwargs) + + def valid_collaborator_cn_and_id(self, cert_common_name, + collaborator_common_name): + """ + Determine if the collaborator certificate and ID are valid for this federation. + + Args: + cert_common_name: Common name for security certificate + collaborator_common_name: Common name for collaborator + + Returns: + bool: True means the collaborator common name matches the name in + the security certificate. + + """ + # if self.test_mode_whitelist is None, then the common_name must + # match collaborator_common_name and be in authorized_cols + # FIXME: '' instead of None is just for protobuf compatibility. + # Cleaner solution? + if self.single_col_cert_common_name == '': + return (cert_common_name == collaborator_common_name + and collaborator_common_name in self.authorized_cols) + # otherwise, common_name must be in whitelist and + # collaborator_common_name must be in authorized_cols + else: + return (cert_common_name == self.single_col_cert_common_name + and collaborator_common_name in self.authorized_cols) + + def get_tasks(self, collaborator_name): + """ + RPC called by a collaborator to determine which tasks to perform. + + Args: + collaborator_name: str + Requested collaborator name + + Returns: + tasks: list[str] + List of tasks to be performed by the requesting collaborator + for the current round. + sleep_time: int + time_to_quit: bool + """ + raise NotImplementedError + + def get_aggregated_tensor(self, collaborator_name, tensor_name, + round_number, report, tags, require_lossless): + """ + RPC called by collaborator. + + Performs local lookup to determine if there is an aggregated tensor available \ + that matches the request. + + Args: + collaborator_name : str + Requested tensor key collaborator name + tensor_name: str + require_lossless: bool + round_number: int + report: bool + tags: tuple[str, ...] + Returns: + named_tensor : protobuf NamedTensor + the tensor requested by the collaborator + """ + raise NotImplementedError + + def send_local_task_results(self, collaborator_name, round_number, task_name, + data_size, named_tensors): + """ + RPC called by collaborator. + + Transmits collaborator's task results to the aggregator. + + Args: + collaborator_name: str + task_name: str + round_number: int + data_size: int + named_tensors: protobuf NamedTensor + Returns: + None + """ + raise NotImplementedError + + def all_quit_jobs_sent(self): + """Assert all quit jobs are sent to collaborators.""" + return set(self.quit_job_sent_to) == set(self.authorized_cols) + + def stop(self, failed_collaborator): + """Stop aggregator execution.""" + if failed_collaborator: + self.quit_job_sent_to.append(failed_collaborator) + + +class CollaboratorInterface: + """ + The class to define a collaborator. + + The experiment manager can define a customized aggregator object for some ad-hoc + federated learning task. + + """ + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def initialize(self, collaborator_name, + aggregator_uuid, + federation_uuid, + client, + **kwargs): + """ + This method is called by the Plan component, to pass the settings to the + collaborator. + + Args: + collaborator_name (string): The common name for the collaborator + aggregator_uuid: The unique id for the client + federation_uuid: The unique id for the federation + """ + + self.single_col_cert_common_name = None + + if self.single_col_cert_common_name is None: + self.single_col_cert_common_name = '' # for protobuf compatibility + # we would really want this as an object + + self.collaborator_name = collaborator_name + self.aggregator_uuid = aggregator_uuid + self.federation_uuid = federation_uuid + + self.client = client + + self.logger = getLogger(__name__) + + self.kwargs.update(kwargs) + + def run(self): + """ + Run the collaborator, called by the Envoy service. + """ + raise NotImplementedError + + def set_available_devices(self, cuda: Tuple[str] = ()): + """ + This method is called by the Envoy service. + + Set available CUDA devices. + + Cuda tuple contains string indeces, ('1', '3'). + """ + self.cuda_devices = cuda