Refactor for students
JMGaljaard committed Apr 19, 2022
1 parent 88a38f6 commit 5e41c10
Showing 50 changed files with 228 additions and 1,928 deletions.
11 changes: 7 additions & 4 deletions Dockerfile
@@ -13,17 +13,20 @@ WORKDIR /opt/federation-lab

# Update the Ubuntu software repository and fetch packages
RUN apt-get update \
&& apt-get install -y curl python3 python3-pip
&& apt-get install -y python3.9

# Setup pip3.9
RUN apt install -y curl python3.9-distutils
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
RUN python3.9 get-pip.py
# Add pre-downloaded models (otherwise the download needs to run every time)
ADD data/ data/

# Use cache for pip, otherwise we repeatedly pull from repository
ADD requirements.txt ./
RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -r requirements.txt
COPY requirements-cpu.txt ./requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip,mode=0777 python3.9 -m pip install -r requirements.txt

# Add FLTK and configurations
ADD fltk fltk
ADD configs configs
ADD experiments experiments
ADD charts charts
1 change: 0 additions & 1 deletion fltk/__init__.py
@@ -1,2 +1 @@

__version__ = '0.4.1'
1 change: 1 addition & 0 deletions fltk/core/comm/__init__.py
@@ -0,0 +1 @@
from .rpc_util import *
9 changes: 8 additions & 1 deletion fltk/core/rpc_util.py → fltk/core/comm/rpc_util.py
@@ -24,7 +24,14 @@ def _remote_method_async_by_info(method, worker_info, *args, **kwargs):
args = [method, worker_info] + list(args)
return rpc.rpc_async(worker_info, _call_method, args=args, kwargs=kwargs)


def _remote_method_direct(method, other_node: str, *args, **kwargs):
"""
Utility function for RPC communication between nodes.
:param method: Callable to invoke on the remote node.
:param other_node: Name of the destination node in the RPC group.
:return: Result of the remote invocation.
"""

args = [method, other_node] + list(args)
# return rpc.rpc_sync(other_node, _call_method, args=args, kwargs=kwargs)
return rpc.rpc_sync(other_node, method, args=args, kwargs=kwargs)
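
For context, a minimal sketch of the torch.distributed.rpc primitive that `_remote_method_direct` and its siblings wrap. The single-process group, node name, and `add_tensors` function are illustrative assumptions, not part of FLTK.

```python
# Minimal, self-contained sketch of the rpc_sync primitive wrapped above.
# The single-process group and all names here are illustrative, not FLTK API.
import os
import torch
import torch.distributed.rpc as rpc

def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b  # executes on the callee side

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # World size 1 keeps the sketch runnable in one process; FLTK runs one RPC worker per node.
    rpc.init_rpc("node0", rank=0, world_size=1)
    result = rpc.rpc_sync("node0", add_tensors, args=(torch.ones(2), torch.ones(2)))
    print(result)  # tensor([2., 2.]), computed via the RPC layer and returned synchronously
    rpc.shutdown()
```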
2 changes: 1 addition & 1 deletion fltk/core/distributed/client.py
@@ -29,7 +29,7 @@ def __init__(self, rank: int, task_id: str, world_size: int, config: Distributed
@param task_id: String id representing the UID of the training task
@type task_id: str
@param config: Parsed configuration file representation to extract runtime information from.
@type config: BareConfig
@type config: DistributedConfig
@param learning_params: Hyper-parameter configuration to be used during the training process by the learner.
@type learning_params: LearningParameters
"""
14 changes: 2 additions & 12 deletions fltk/core/node.py
@@ -1,3 +1,4 @@
import abc
import copy
import os
from typing import Callable, Any
@@ -12,18 +13,7 @@
global_vars = {}


def _remote_method_direct(method, other_node: str, *args, **kwargs):
"""
Utility function for RPC communication between nodes
:param method: A callable
:param other_node: reference to other node
:return: any
"""
args = [method, other_node] + list(args)
return rpc.rpc_sync(other_node, method, args=args, kwargs=kwargs)


class Node:
class Node(abc.ABC):
"""
Implementation of any participating node.
It handles communication and the basic functions for Deep Learning.
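
To illustrate the effect of making Node abstract, a small hypothetical sketch follows; the `run` method and `Client` subclass are stand-ins, not the actual FLTK interface.

```python
import abc

class Node(abc.ABC):  # hypothetical stand-in for fltk.core.node.Node
    def __init__(self, node_id: str):
        self.node_id = node_id

    @abc.abstractmethod
    def run(self) -> None:
        """Concrete roles (client, federator) must provide their own main loop."""

class Client(Node):
    def run(self) -> None:
        print(f"{self.node_id}: running local training")

# Node("n0") would now raise TypeError: abstract classes cannot be instantiated directly.
Client("client-0").run()
```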
2 changes: 1 addition & 1 deletion fltk/datasets/loader_util.py
@@ -1,5 +1,5 @@
from fltk.datasets.distributed import DistMNISTDataset, DistFashionMNISTDataset, DistCIFAR100Dataset, DistCIFAR10Dataset
from fltk.util.definitions import Dataset
from fltk.util.config.definitions.dataset import Dataset


def available_datasets():
11 changes: 6 additions & 5 deletions fltk/launch.py
@@ -15,6 +15,7 @@
from fltk.core.distributed import Orchestrator
from fltk.core.distributed.extractor import download_datasets
from fltk.core.federator import Federator
from fltk.nets.util.reproducability import init_reproducibility
from fltk.util.cluster.client import ClusterManager
from fltk.util.config import DistributedConfig, Config
from fltk.util.config.arguments import LearningParameters, extract_learning_parameters
@@ -41,7 +42,7 @@ def launch_distributed_client(task_id: str, conf: DistributedConfig = None,
@param task_id: String representation (should be unique) corresponding to a client.
@type task_id: str
@param conf: Configuration for components, needed for spinning up components of the Orchestrator.
@type conf: BareConfig
@type conf: DistributedConfig
@param learning_params: Parsed configuration of Hyper-Parameters for learning.
@type: LearningParameters
@return: None
@@ -71,7 +72,7 @@ def launch_orchestrator(args: Namespace = None, conf: DistributedConfig = None,
@type args: Namespace
@param config: Configuration for execution of Orchestrators components, needed for spinning up components of the
Orchestrator.
@type config: BareConfig
@type config: Optional[DistributedConfig]
@return: None
@rtype: None
"""
@@ -86,7 +87,7 @@
logging.info("Pointing configuration to in cluster configuration.")
conf.cluster_config.load_incluster_namespace()
conf.cluster_config.load_incluster_image()
arrival_generator = DistributedExperimentGenerator() if simulate_arrivals else FederatedArrivalGenerator()
arrival_generator = (DistributedExperimentGenerator if simulate_arrivals else FederatedArrivalGenerator)(args.experiment)
cluster_manager = ClusterManager()

orchestrator = Orchestrator(cluster_manager, arrival_generator, conf)
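
The generator line above switches from instantiating with no arguments to selecting a class first and then constructing it with `args.experiment`. A small sketch of that idiom, using stand-in classes rather than the real FLTK generators:

```python
# Stand-in classes; the real generators live in FLTK's arrival-generator module.
class DistributedExperimentGenerator:
    def __init__(self, experiment_path: str):
        self.experiment_path = experiment_path

class FederatedArrivalGenerator:
    def __init__(self, experiment_path: str):
        self.experiment_path = experiment_path

simulate_arrivals = True
# Pick the class first, then call the chosen class once with the shared constructor argument.
generator = (DistributedExperimentGenerator if simulate_arrivals else FederatedArrivalGenerator)("experiments/example.yaml")
print(type(generator).__name__)  # DistributedExperimentGenerator
```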
@@ -114,7 +115,7 @@ def launch_extractor(base_path: Path, config_path: Path, args: Namespace = None,
@param args: Arguments passed from CLI.
@type args: Namespace
@param conf: Parsed configuration file passed from the CLI.
@type conf: BareConfig
@type conf: Optional[DistributedConfig]
@return: None
@rtype: None
"""
@@ -295,5 +296,5 @@ def launch_cluster(arg_path, conf_path, args: Namespace = None, config: Distribu
datefmt='%m-%d %H:%M')
# Set the seed for arrivals, torch seed is mostly ignored. Set the `arrival_seed` to a different value
# for each repetition that you want to run an experiment with.
config.set_seed()
init_reproducibility(config.execution_config)
launch_orchestrator(args=args, conf=config)
2 changes: 1 addition & 1 deletion fltk/nets/__init__.py
@@ -2,7 +2,7 @@

import torch

from fltk.util.definitions import Nets
from ..util.config.definitions.net import Nets
from .cifar_100_resnet import Cifar100ResNet
from .cifar_100_vgg import Cifar100VGG, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from .cifar_10_cnn import Cifar10CNN
2 changes: 1 addition & 1 deletion fltk/nets/cifar_10_cnn.py
@@ -1,6 +1,6 @@
# pylint: disable=missing-class-docstring,invalid-name
import torch
import torch.torch.nn.functional as F
import torch.nn.functional as F

class Cifar10CNN(torch.nn.Module):

2 changes: 1 addition & 1 deletion fltk/nets/cifar_10_resnet.py
@@ -1,6 +1,6 @@
# pylint: disable=missing-class-docstring,invalid-name
import torch
import torch.torch.nn.functional as F
import torch.nn.functional as F


class BasicBlock(torch.nn.Module):
4 changes: 2 additions & 2 deletions fltk/nets/util/model.py
@@ -6,7 +6,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter

import fltk.util.config as config
from fltk.util.config import DistributedConfig
from fltk.util.results import EpochData


@@ -46,7 +46,7 @@ def recover_flattened(flat_params, model):
return recovered_params


def initialize_default_model(conf: config.DistributedConfig, model_class) -> torch.nn.Module:
def initialize_default_model(conf: DistributedConfig, model_class) -> torch.nn.Module:
"""
Load a default model dictionary into a torch model.
@param model:
23 changes: 10 additions & 13 deletions fltk/nets/util/reproducability.py
@@ -3,6 +3,8 @@
import numpy as np
import torch

from fltk.util.config.distributed_config import ExecutionConfig


def cuda_reproducible_backend(cuda: bool) -> None:
"""
@@ -21,25 +23,20 @@ def cuda_reproducible_backend(cuda: bool) -> None:
torch.backends.cudnn.deterministic = False


def init_reproducibility(torch_seed: int = 42, cuda: bool = False, numpy_seed: int = 43, hash_seed: int = 44) -> None:
def init_reproducibility(config: ExecutionConfig) -> None:
"""
Function to pre-set all seeds for libraries used during training. Allows for reproducible network initialization
and random number generation, preventing 'lucky' draws in network initialization.
@param torch_seed: Integer seed to use for the PyTorch PRNG and CUDA PRNG.
@type torch_seed: int
@param cuda: Flag to indicate whether the CUDA backend needs to be
@type cuda: bool
@param numpy_seed: Integer seed to use for NumPy's PRNG.
@type numpy_seed: int
@param hash_seed: Integer seed to use for Pythons Hash function PRNG, will set the
@type hash_seed: int
@param config: Execution configuration for the experiments to be run on the remote cluster.
@type config: ExecutionConfig
@return: None
@rtype: None
"""
random_seed = config.reproducibility.arrival_seed
torch_seed = config.reproducibility.torch_seed
torch.manual_seed(torch_seed)
if cuda:
if config.cuda:
torch.cuda.manual_seed_all(torch_seed)
cuda_reproducible_backend(True)
np.random.seed(numpy_seed)
os.environ['PYTHONHASHSEED'] = str(hash_seed)
np.random.seed(random_seed)
os.environ['PYTHONHASHSEED'] = str(random_seed)
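
A hedged sketch of calling the refactored function; the dataclass shapes below are inferred only from the attributes used in this diff (`reproducibility.arrival_seed`, `reproducibility.torch_seed`, `cuda`) and are not the real ExecutionConfig definition.

```python
from dataclasses import dataclass, field

@dataclass
class ReproducibilityConfig:  # assumed shape, inferred from the attributes referenced above
    arrival_seed: int = 43
    torch_seed: int = 42

@dataclass
class ExecutionConfigSketch:  # assumed shape; the real class lives in fltk.util.config.distributed_config
    reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig)
    cuda: bool = False

# Seeds torch (and CUDA when enabled), numpy, and PYTHONHASHSEED in a single call.
init_reproducibility(ExecutionConfigSketch())
```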
2 changes: 1 addition & 1 deletion fltk/samplers/__init__.py
@@ -5,7 +5,7 @@
from .dirichlet import DirichletSampler
from .limit_labels import LimitLabelsSampler
from .limit_labels_flex import LimitLabelsSamplerFlex
from ..util.definitions import DataSampler
from ..util.config.definitions import DataSampler
from ..util.log import getLogger


1 change: 0 additions & 1 deletion fltk/strategy/__init__.py
@@ -1,4 +1,3 @@
from .aggregation import *
from .client_selection import *
from .optimization import *
from .offloading import OffloadingStrategy, parse_strategy
9 changes: 2 additions & 7 deletions fltk/strategy/aggregation/__init__.py
@@ -1,10 +1,5 @@
from typing import Callable

import torch

from fltk.util.definitions import Aggregations
from .FedAvg import fed_avg
from .aggregation import average_nn_parameters, average_nn_parameters_simple
from fltk.util.config.definitions.aggregate import Aggregations
from .fed_avg import fed_avg


def get_aggregation(name: Aggregations):
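
The aggregation helper resolves names such as FedAvg to concrete functions. Since the body of `get_aggregation` is truncated in this diff, the sketch below only illustrates the weighted-averaging idea behind `fed_avg`; it is not the FLTK implementation.

```python
import torch

def fed_avg_sketch(client_params: list[dict], sizes: list[int]) -> dict:
    """Illustrative weighted average of client state dicts, weighted by local dataset size."""
    total = sum(sizes)
    return {
        key: sum(params[key] * (size / total) for params, size in zip(client_params, sizes))
        for key in client_params[0]
    }

clients = [{"w": torch.ones(2)}, {"w": torch.zeros(2)}]
print(fed_avg_sketch(clients, sizes=[3, 1]))  # {'w': tensor([0.7500, 0.7500])}
```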
File renamed without changes.
97 changes: 0 additions & 97 deletions fltk/strategy/offloading.py

This file was deleted.

11 changes: 7 additions & 4 deletions fltk/strategy/optimization/__init__.py
@@ -2,20 +2,23 @@

import torch

from fltk.util.definitions import Optimizations
from .fedprox import FedProx
from .FedNova import FedNova
from fltk.util.config.definitions.optim import Optimizations
from .fed_prox import FedProx
from .fed_nova import FedNova


def get_optimizer(name: Optimizations) -> Type[torch.optim.Optimizer]:
"""
Helper function to get specific Optimization class references.
@param name: Optimizer class reference.
@type name: Optimizations
@return: Class reference corresponding to the requested Optimizations definition.
@return: Class reference corresponding to the requested Optimizations definition. The caller must instantiate it
with the args and kwargs appropriate for the chosen optimizer type.
@rtype: Type[torch.optim.Optimizer]
"""
optimizers = {
Optimizations.adam: torch.optim.Adam,
Optimizations.adam_w: torch.optim.AdamW,
Optimizations.sgd: torch.optim.SGD,
Optimizations.fedprox: FedProx,
Optimizations.fednova: FedNova
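
A short usage sketch of the helper above: it returns a class reference, which the caller instantiates with the model parameters and optimizer-specific keyword arguments (the learning rate shown is an arbitrary example, not an FLTK default).

```python
import torch
from fltk.strategy.optimization import get_optimizer
from fltk.util.config.definitions.optim import Optimizations

model = torch.nn.Linear(10, 2)
optimizer_cls = get_optimizer(Optimizations.sgd)        # class reference, e.g. torch.optim.SGD
optimizer = optimizer_cls(model.parameters(), lr=0.01)  # kwargs depend on the chosen optimizer
```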
File renamed without changes.
File renamed without changes.
1 change: 0 additions & 1 deletion fltk/util/__init__.py
@@ -1 +0,0 @@
from fltk.util.reproducability import init_reproducibility