Skip to content

Commit

Permalink
run local & federated mlp training with job api
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Dec 5, 2024
1 parent 52388af commit a1a1b93
Show file tree
Hide file tree
Showing 18 changed files with 262 additions and 9,941 deletions.

This file was deleted.

This file was deleted.

This file was deleted.

Binary file not shown.
11 changes: 0 additions & 11 deletions examples/advanced/bionemo/task_fitting/jobs/embeddings/meta.json

This file was deleted.

This file was deleted.

This file was deleted.

10 changes: 0 additions & 10 deletions examples/advanced/bionemo/task_fitting/jobs/fedavg/meta.json

This file was deleted.

2 changes: 1 addition & 1 deletion examples/advanced/bionemo/task_fitting/split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,6 @@ def split(proteins, num_sites, split_dir=".", alpha=1.0, seed=0, concat=False):
)

print(
f"Saved {len(df_split_train_proteins)} training and {len(test_proteins)} testing proteins for {client_name}, "
f"Saved {len(df_split_train_proteins)} training and {len(df_test_proteins)} testing proteins for {client_name}, "
f"({len(set(df_split_train_proteins['labels']))}/{len(set(df_test_proteins['labels']))}) unique train/test classes."
)
111 changes: 111 additions & 0 deletions examples/advanced/bionemo/task_fitting/src/bionemo_mlp_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from torch import nn as nn

from bionemo_mlp_model_persistor import BioNeMoMLPModelPersistor

from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job


class BioNeMoMLPJob(FedJob):
def __init__(
self,
initial_model: nn.Module = None,
name: str = "fed_job",
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
key_metric: str = "accuracy",
validation_json_generator: Optional[ValidationJsonGenerator] = None,
intime_model_selector: Optional[IntimeModelSelector] = None,
convert_to_fed_event: Optional[ConvertToFedEvent] = None,
analytics_receiver: Optional[AnalyticsReceiver] = None,
):
"""PyTorch BaseFedJob.
Configures ValidationJsonGenerator, IntimeModelSelector, AnalyticsReceiver, ConvertToFedEvent.
User must add controllers and executors.
Args:
initial_model (nn.Module): initial PyTorch Model. Defaults to None.
name (name, optional): name of the job. Defaults to "fed_job".
min_clients (int, optional): the minimum number of clients for the job. Defaults to 1.
mandatory_clients (List[str], optional): mandatory clients to run the job. Default None.
key_metric (str, optional): Metric used to determine if the model is globally best.
if metrics are a `dict`, `key_metric` can select the metric used for global model selection.
Defaults to "accuracy".
validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results.
if not provided, a ValidationJsonGenerator will be configured.
intime_model_selector: (IntimeModelSelector, optional): A component for select the model.
if not provided, an IntimeModelSelector will be configured.
convert_to_fed_event: (ConvertToFedEvent, optional): A component to covert certain events to fed events.
if not provided, a ConvertToFedEvent object will be created.
analytics_receiver (AnlyticsReceiver, optional): Receive analytics.
If not provided, a TBAnalyticsReceiver will be configured.
"""
super().__init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
)

self.initial_model = initial_model
self.comp_ids = {}

if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
else:
analytics_receiver = TBAnalyticsReceiver()

self.to_server(
id="receiver",
obj=analytics_receiver,
)

self.to_server(id="persistor", obj=BioNeMoMLPModelPersistor())

self.to_server(id="locator", obj=PTFileModelLocator(pt_persistor_id="persistor"))

def set_up_client(self, target: str):
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
Loading

0 comments on commit a1a1b93

Please sign in to comment.