Commit a8640ab
Update name of federated learning configuration class
JMGaljaard committed May 11, 2022
1 parent be5d595 commit a8640ab
Showing 10 changed files with 25 additions and 24 deletions.
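
This commit mechanically renames the federated learning configuration class from Config to FedLearningConfig (defined in fltk/util/config/config.py) and updates every import, type annotation, and call site to match. A minimal sketch of the resulting downstream usage, with a hypothetical configs/example.yaml path, mirroring the launch_single pattern in fltk/launch.py:

from pathlib import Path

from fltk.core.federator import Federator
from fltk.util.config import FedLearningConfig  # formerly named Config

# Hypothetical experiment file; world_size follows the num_clients + 1 rule from launch.py.
conf = FedLearningConfig.from_yaml(Path("configs/example.yaml"))
conf.world_size = 3  # e.g. two clients plus the federator
federator_node = Federator('federator', 0, conf.world_size, conf)
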
4 changes: 2 additions & 2 deletions fltk/core/client.py
@@ -8,7 +8,7 @@
 from fltk.core.node import Node
 from fltk.schedulers import MinCapableStepLR
 from fltk.strategy import get_optimizer
-from fltk.util.config import Config
+from fltk.util.config import FedLearningConfig
 
 
 class Client(Node):
@@ -17,7 +17,7 @@ class Client(Node):
     """
     running = False
 
-    def __init__(self, identifier: str, rank: int, world_size: int, config: Config):
+    def __init__(self, identifier: str, rank: int, world_size: int, config: FedLearningConfig):
         super().__init__(identifier, rank, world_size, config)
 
         self.loss_function = self.config.get_loss_function()()

4 changes: 2 additions & 2 deletions fltk/core/federator.py
@@ -13,7 +13,7 @@
 from fltk.core.node import Node
 from fltk.strategy import get_aggregation
 from fltk.strategy import random_selection
-from fltk.util.config import Config
+from fltk.util.config import FedLearningConfig
 from fltk.util.data_container import DataContainer, FederatorRecord, ClientRecord
 
 NodeReference = Union[Node, str]
@@ -57,7 +57,7 @@ class Federator(Node):
     num_rounds: int
     exp_data: DataContainer
 
-    def __init__(self, identifier: str, rank: int, world_size: int, config: Config):
+    def __init__(self, identifier: str, rank: int, world_size: int, config: FedLearningConfig):
         super().__init__(identifier, rank, world_size, config)
         self.loss_function = self.config.get_loss_function()()
         self.num_rounds = config.rounds

6 changes: 3 additions & 3 deletions fltk/core/node.py
@@ -6,7 +6,7 @@
 from torch.distributed import rpc
 from fltk.datasets.loader_util import get_dataset
 from fltk.nets import get_net
-from fltk.util.config import Config
+from fltk.util.config import FedLearningConfig
 from fltk.util.log import getLogger
 
 # Global dictionary to enable peer to peer communication between clients
@@ -30,7 +30,7 @@ class Node(abc.ABC):
     dataset: Any
     logger = getLogger(__name__)
 
-    def __init__(self, identifier: str, rank: int, world_size: int, config: Config):
+    def __init__(self, identifier: str, rank: int, world_size: int, config: FedLearningConfig):
         self.config = config
         self.id = identifier  # pylint: disable=invalid-name
         self.rank = rank
@@ -40,7 +40,7 @@ def __init__(self, identifier: str, rank: int, world_size: int, config: Config):
         global_vars['self'] = self
         self._config(config)
 
-    def _config(self, config: Config):
+    def _config(self, config: FedLearningConfig):
         self.logger.setLevel(config.log_level.value)
         self.config.rank = self.rank
         self.config.world_size = self.world_size

4 changes: 2 additions & 2 deletions fltk/datasets/distributed/cifar10.py
@@ -4,15 +4,15 @@
 
 from fltk.datasets.distributed.dataset import DistDataset
 from fltk.samplers import get_sampler
-from fltk.util.config import Config
+from fltk.util.config import FedLearningConfig
 
 
 class DistCIFAR10Dataset(DistDataset):
     """
     CIFAR10 Dataset implementation for Distributed learning experiments.
     """
 
-    def __init__(self, args: Config):
+    def __init__(self, args: FedLearningConfig):
         super(DistCIFAR10Dataset, self).__init__(args)
         self.init_train_dataset()
         self.init_test_dataset()

4 changes: 2 additions & 2 deletions fltk/datasets/distributed/dataset.py
@@ -2,7 +2,7 @@
 from typing import Any
 
 # from fltk.util.arguments import Arguments
-from fltk.util.config import Config
+from fltk.util.config import FedLearningConfig
 from fltk.util.log import getLogger
 
 
@@ -15,7 +15,7 @@ class DistDataset:
     test_loader = None
     logger = getLogger(__name__)
 
-    def __init__(self, args: Config):
+    def __init__(self, args: FedLearningConfig):
         self.args = args
 
     def get_args(self):

6 changes: 3 additions & 3 deletions fltk/launch.py
@@ -20,7 +20,7 @@
 from fltk.nets.util.reproducability import init_reproducibility, init_learning_reproducibility
 from fltk.util.cluster.client import ClusterManager
 from fltk.util.cluster.worker import should_distribute
-from fltk.util.config import DistributedConfig, Config, retrieve_config_network_params, get_learning_param_config, \
+from fltk.util.config import DistributedConfig, FedLearningConfig, retrieve_config_network_params, get_learning_param_config, \
     DistLearningConfig
 from fltk.util.environment import retrieve_or_init_env, retrieve_env_config
 from fltk.util.task.generator.arrival_generator import SimulatedArrivalGenerator, SequentialArrivalGenerator
@@ -204,7 +204,7 @@ def launch_single(arg_path: Path, conf_path: Path, rank: Rank, nic: Optional[NIC
     # We can iterate over all the experiments in the directory and execute it, as long as the system remains the same!
     # System = machines and its configuration
     print(conf_path)
-    s_conf = Config.from_yaml(conf_path)
+    s_conf = FedLearningConfig.from_yaml(conf_path)
     s_conf.world_size = conf.num_clients + 1
     s_conf.replication_id = prefix
     federator_node = Federator('federator', 0, conf.world_size, s_conf)
@@ -238,7 +238,7 @@ def launch_remote(arg_path: Path, conf_path: Path, rank: Rank, nic: Optional[NIC
     @return: None
     @rtype: None
     """
-    r_conf = Config.from_yaml(conf_path)
+    r_conf = FedLearningConfig.from_yaml(conf_path)
     r_conf.world_size = r_conf.num_clients + 1
     r_conf.replication_id = prefix
     if rank and not (nic and host):

4 changes: 2 additions & 2 deletions fltk/util/config/__init__.py
@@ -6,10 +6,10 @@
 import logging
 
 from fltk.util.config.distributed_config import DistributedConfig
-from fltk.util.config.config import Config, get_safe_loader, DistLearningConfig
+from fltk.util.config.config import FedLearningConfig, get_safe_loader, DistLearningConfig
 
 
-def retrieve_config_network_params(conf: Config, nic=None, host=None):
+def retrieve_config_network_params(conf: FedLearningConfig, nic=None, host=None):
     if hasattr(conf, 'system'):
         system_attr = getattr(conf, 'system')
         if 'federator' in system_attr:

9 changes: 5 additions & 4 deletions fltk/util/config/config.py
@@ -48,9 +48,10 @@ def get_safe_loader() -> yaml.SafeLoader:
         list(u'-+0123456789.'))
     return safe_loader
 
+
 @dataclass_json
 @dataclass
-class Config:
+class FedLearningConfig:
     batch_size: int = 1
     test_batch_size: int = 1000
     rounds: int = 2
@@ -141,18 +141,18 @@ def from_yaml(path: Path):
         you prefer to create json based configuration files.
         >>> with open("configs/example.json") as f:
-        >>>     Config.from_json(f.read())
+        >>>     FedLearningConfig.from_json(f.read())
         @param path: Path pointing to configuration yaml file.
         @type path: Path
         @return: Configuration dataclass representation of the configuration file.
-        @rtype: Config
+        @rtype: FedLearningConfig
         """
         getLogger(__name__).debug(f'Loading yaml from {path.absolute()}')
         safe_loader = get_safe_loader()
         with open(path) as file:
             content = yaml.load(file, Loader=safe_loader)
-        conf = Config.from_dict(content)
+        conf = FedLearningConfig.from_dict(content)
         return conf

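Because FedLearningConfig keeps the @dataclass_json decorator, the generated helpers (from_dict, from_json) carry over unchanged; only the class references inside from_yaml and its docstring needed updating. A minimal sketch of the from_dict path that from_yaml builds on, with illustrative field values:

from fltk.util.config import FedLearningConfig

# from_dict is generated by dataclass_json; fields omitted here keep their defaults.
conf = FedLearningConfig.from_dict({'batch_size': 32, 'rounds': 5})
assert conf.rounds == 5
assert conf.test_batch_size == 1000  # default from the dataclass definition
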
2 changes: 1 addition & 1 deletion fltk/util/config/distributed_config.py
@@ -167,7 +167,7 @@ def get_data_path(self) -> Path:
 
     def get_default_model_folder_path(self) -> Path:
         """
-        @deprecated Function to get the default model folder path from Config, needed for non-default training in the
+        @deprecated Function to get the default model folder path from FedLearningConfig, needed for non-default training in the
        FLTK framework.
         @return: Path representation of model path.
         @rtype: Path

6 changes: 3 additions & 3 deletions tests/core/client_smoke_test.py
@@ -9,7 +9,7 @@
 
 from fltk.core.client import Client
 from fltk.core.distributed import DistClient
-from fltk.util.config import DistributedConfig, get_distributed_config, get_learning_param_config, Config, \
+from fltk.util.config import DistributedConfig, get_distributed_config, get_learning_param_config, FedLearningConfig, \
     DistLearningConfig
 
 from fltk.datasets.dataset import Dataset as DS
@@ -56,10 +56,10 @@ def test_parallel_client(self, name, net: Nets, dataset: Dataset):
 
 
 class TestFederatedLearnerSmoke(unittest.TestCase):
-    learning_config: Config = None
+    learning_config: FedLearningConfig = None
 
     def setUp(self):
-        self.learning_config = Config.from_yaml(Path(TEST_PARAM_CONF_FEDERATED))
+        self.learning_config = FedLearningConfig.from_yaml(Path(TEST_PARAM_CONF_FEDERATED))
 
     @parameterized.expand(
         [[f"{x.value}-{y.value}", x, y] for x, y in MODEL_SET_PAIRING]
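
For downstream test suites, the migration is the same two-step change the smoke test shows: update the annotation and the from_yaml call. A sketch under that assumption, with a hypothetical config path standing in for TEST_PARAM_CONF_FEDERATED:

import unittest
from pathlib import Path

from fltk.util.config import FedLearningConfig


class TestFedLearningConfigRename(unittest.TestCase):
    def test_from_yaml_returns_renamed_class(self):
        # Hypothetical YAML path; any valid federated experiment config works here.
        conf = FedLearningConfig.from_yaml(Path('tests/configs/federated.yaml'))
        self.assertIsInstance(conf, FedLearningConfig)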
