Point towards example cloud experiment description
JMGaljaard committed Mar 28, 2022
1 parent e2338bc commit 23064b6
Showing 10 changed files with 90 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -29,4 +29,4 @@ ADD scripts scripts
EXPOSE 5000

# Update relevant runtime configuration for experiment
COPY cloud_configs/cloud_experiment.yaml configs/cloud_config.yaml
COPY configs/example_cloud_experiment.json configs/example_cloud_experiment.json
2 changes: 1 addition & 1 deletion charts/federator/templates/fl-server-pod.yaml
@@ -15,7 +15,7 @@ spec:
- -m
- fltk
- cluster
- configs/cloud_config.yaml
- configs/example_cloud_experiment.json
- --rank=0
env:
- name: MASTER_PORT
2 changes: 1 addition & 1 deletion configs/example_cloud_experiment.json
@@ -15,7 +15,7 @@
"cuda": false,
"tensorboard": {
"active": true,
"record_dir": true
"record_dir": "logging"
},
"net": {
"save_model": false,
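
The corrected "tensorboard" block above maps onto the TensorboardConfig dataclass introduced further down in this diff. A minimal parsing sketch, not part of the commit, using the from_dict helper that dataclasses_json generates:

# Illustrative sketch, not part of this commit: parsing the corrected
# "tensorboard" block with the TensorboardConfig dataclass from
# fltk/util/config/base_config.py (a dataclasses_json dataclass).
from fltk.util.config.base_config import TensorboardConfig

tb_config = TensorboardConfig.from_dict({"active": True, "record_dir": "logging"})
assert tb_config.active and tb_config.record_dir == "logging"
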
8 changes: 0 additions & 8 deletions fltk/__main__.py
@@ -74,15 +74,7 @@ def cluster_start(args: dict, config: BareConfig):
"""
logging.info("[Fed] Starting in cluster mode.")
# TODO: Load configuration path
config_path: Path = None
cluster_manager = ClusterManager()
arrival_generator = ExperimentGenerator(config_path)

pool = ThreadPool(4)
pool.apply(cluster_manager.start)
pool.apply(arrival_generator.run)

pool.join()

print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=cfg, nic={nic}')
run_single(rank=args.rank, args=config, nic=nic)
7 changes: 3 additions & 4 deletions fltk/launch.py
@@ -4,16 +4,15 @@
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

from fltk.orchestrator import run_ps
from fltk.orchestrator import run_orchestrator
from fltk.util.env.learner_environment import prepare_environment

logging.basicConfig(level=logging.INFO)


def await_assigned_orchestrator():
# TODO: Implement await function for client

"""
TODO:
1. Setup everything correctly according to provided configuration files.
2. Register to client
3. Start working on task description provided by orchestrator
@@ -51,7 +50,7 @@ def run_single(rank, world_size, host=None, args=None, nic=None):
world_size=world_size,
rpc_backend_options=options
)
run_ps([(f"client{r}", r, world_size) for r in range(1, world_size)], args)
run_orchestrator([(f"client{r}", r, world_size) for r in range(1, world_size)], args)

# block until all rpc finish
rpc.shutdown()
59 changes: 35 additions & 24 deletions fltk/orchestrator.py
@@ -1,6 +1,7 @@
import logging
import pathlib
import time
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import List, Callable, Dict

import torch
@@ -10,11 +11,11 @@

from fltk.client import Client
from fltk.nets.util.utils import flatten_params, save_model
from fltk.util.cluster.client import ClientRef, ClusterManager
from fltk.util.config.base_config import BareConfig
from fltk.util.cluster.client import ClientRef
from fltk.util.task.generator.arrival_generator import ArrivalGenerator
from fltk.util.log import DistLearningLogger
from fltk.util.results import EpochData
from fltk.util.task.generator.arrival_generator import ArrivalGenerator, ExperimentGenerator


def _call_method(method, rref, *args, **kwargs):
@@ -144,12 +145,22 @@ def client_reset_model(self):
for client in self.clients:
_remote_method_async(Client.reset_model, client.ref)

def client_load_data(self, poison_pill):
def client_load_data(self):
"""
TODO: Make this compatible with job registration...
@return:
@rtype:
"""
for client in self.clients:
_remote_method_async(Client.init_dataloader, client.ref,
pill=None if poison_pill and client not in self.poisoned_clients else poison_pill)
_remote_method_async(Client.init_dataloader, client.ref)

def clients_ready(self):
"""
TODO: Make compatible with Job registration
TODO: Make push based instead of pull based.
@return:
@rtype:
"""
all_ready = False
ready_clients = []
while not all_ready:
@@ -271,19 +282,11 @@ def run(self):

# # Select clients which will be poisoned
# TODO: get attack type and ratio from config, temp solution now
poison_pill = None
save_path = self.config

# Update writer to logdir
self.update_clients(rat)
if self.attack:
self.poisoned_clients: List[ClientRef] = self.attack.select_poisoned_clients(self.clients, rat)
self.healthy_clients = list(set(self.clients).symmetric_difference(set(self.poisoned_clients)))
print(f"Poisoning workers: {self.poisoned_clients}")
with open(f"{self.tb_path_base}/config_{rat}_poisoned.txt", 'w') as f:
f.writelines(list(map(lambda worker: worker.name, self.poisoned_clients)))
poison_pill = self.attack.get_poison_pill()
self.client_load_data(poison_pill)
save_path = Path(self.config.execution_config.general_net.save_model_path)
logging_dir = self.config.execution_config.tensorboard.record_dir

# self.update_clients()
self.client_load_data()
self.ping_all()
self.clients_ready()
self.update_client_data_sizes()
@@ -330,7 +333,7 @@ def store_gradient(self, gradient, client_id, epoch, ratio):
"""
directory: str = f"{self.tb_path_base}/gradient/{ratio}/{epoch}/{client_id}"
# Ensure path exists (mkdir -p)
pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
Path(directory).mkdir(parents=True, exist_ok=True)
# Save using pytorch.
torch.save(gradient, f"{directory}/gradient.pt")

@@ -350,7 +353,15 @@ def distribute_new_model(self, updated_model) -> None:
logging.info('Weights are updated')


def run_ps(rpc_ids_triple, args):
print(f'Starting the federator...')
fed = Orchestrator(rpc_ids_triple, config=args)
fed.run()
def run_orchestrator(rpc_ids_triple, configuration: BareConfig, config_path: Path):
logging.info("Starting Orchestrator, initializing resources....")
orchestrator = Orchestrator(rpc_ids_triple, config=configuration)
cluster_manager = ClusterManager()
arrival_generator = ExperimentGenerator(config_path)

pool = ThreadPool(3)
pool.apply(cluster_manager.start)

pool.apply(arrival_generator.run, {'orchestrator': orchestrator})
pool.apply(orchestrator.run, {'cluster_manager': cluster_manager})
pool.join()
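
Note that ThreadPool.apply blocks until its callable returns and takes positional arguments as a tuple; keyword arguments go through the kwds parameter. A non-blocking variant of the same wiring, shown purely as a hypothetical sketch rather than as part of this commit, could look like:

# Illustrative sketch only, not part of this commit: the same wiring with
# non-blocking submissions via apply_async; keyword arguments are passed
# through the kwds parameter of ThreadPool.
from multiprocessing.pool import ThreadPool

def run_orchestrator_concurrently(orchestrator, cluster_manager, arrival_generator):
    pool = ThreadPool(3)
    pool.apply_async(cluster_manager.start)
    pool.apply_async(arrival_generator.run, kwds={'orchestrator': orchestrator})
    pool.apply_async(orchestrator.run, kwds={'cluster_manager': cluster_manager})
    pool.close()  # no further tasks will be submitted
    pool.join()   # wait for all three workers to finish
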
36 changes: 21 additions & 15 deletions fltk/strategy/data_samplers.py
@@ -1,4 +1,3 @@
import random
import logging
import random
from collections import Counter
@@ -11,16 +10,16 @@
class DistributedSamplerWrapper(DistributedSampler):
indices = []
epoch_size = 1.0
def __init__(self, dataset: Dataset, num_replicas = None,
rank = None, seed = 0) -> None:

def __init__(self, dataset: Dataset, num_replicas=None,
rank=None, seed=0) -> None:
super().__init__(dataset, num_replicas=num_replicas, rank=rank)

self.client_id = rank - 1
self.n_clients = num_replicas - 1
self.n_labels = len(dataset.classes)
self.seed = seed


def order_by_label(self, dataset):
# order the indices by label
ordered_by_label = [[] for i in range(len(dataset.classes))]
@@ -40,35 +39,37 @@ def set_epoch_size(self, epoch_size: float) -> None:
self.epoch_size = epoch_size

def __iter__(self) -> Iterator[int]:
random.seed(self.rank+self.epoch)
random.seed(self.rank + self.epoch)
epochs_todo = self.epoch_size
indices = []
while(epochs_todo > 0.0):
while (epochs_todo > 0.0):
random.shuffle(self.indices)
if epochs_todo >= 1.0:
indices.extend(self.indices)
else:
end_index = int(round(len(self.indices)*epochs_todo))
end_index = int(round(len(self.indices) * epochs_todo))
indices.extend(self.indices[:end_index])

epochs_todo = epochs_todo - 1

ratio = len(indices)/float(len(self.indices))
ratio = len(indices) / float(len(self.indices))
np.testing.assert_almost_equal(ratio, self.epoch_size, decimal=2)

return iter(indices)

def __len__(self) -> int:
return len(self.indices)


class LimitLabelsSampler(DistributedSamplerWrapper):
"""
A sampler that limits the number of labels per client
"""

def __init__(self, dataset, num_replicas, rank, args=(5, 42)):
limit, seed = args
super().__init__(dataset, num_replicas, rank, seed)

if self.n_clients % self.n_labels != 0:
logging.error(
"multiples of {} clients are needed for the 'limiting-labels' data distribution method, {} does not work".format(
@@ -145,6 +146,7 @@ def __init__(self, dataset, num_replicas, rank, args=(5, 42)):

self.indices = indices


class Probability_q_Sampler(DistributedSamplerWrapper):
"""
Clients are divided among M groups, with M being the number of labels.
@@ -157,11 +159,11 @@ class Probability_q_Sampler(DistributedSamplerWrapper):
def __init__(self, dataset, num_replicas, rank, args=(0.5, 42)):
q, seed = args
super().__init__(dataset, num_replicas, rank, seed)

if self.n_clients % self.n_labels != 0:
logging.error(
"multiples of {} clients are needed for the 'probability-q-sampler' data distribution method, {} does not work".format(
self.n_labels,self.n_clients))
self.n_labels, self.n_clients))
return

# divide data among groups
@@ -195,13 +197,15 @@ def __init__(self, dataset, num_replicas, rank, args=(0.5, 42)):

self.indices = indices


class DirichletSampler(DistributedSamplerWrapper):
""" Generates a (non-iid) data distribution by sampling the dirichlet distribution. Dirichlet constructs a
vector of length num_clients, that sums to one. Decreasing alpha results in a more non-iid data set.
This distribution method results in both label and quantity skew.
"""
def __init__(self, dataset: Dataset, num_replicas = None,
rank = None, args = (0.5, 42)) -> None:

def __init__(self, dataset: Dataset, num_replicas=None,
rank=None, args=(0.5, 42)) -> None:
alpha, seed = args
super().__init__(dataset, num_replicas=num_replicas, rank=rank, seed=seed)

@@ -211,7 +215,7 @@ def __init__(self, dataset: Dataset, num_replicas = None,
for labels in ordered_by_label:
n_samples = len(labels)
# generate an allocation by sampling dirichlet, which results in how many samples each client gets
allocation = np.random.dirichlet([alpha] * self.n_clients) * n_samples
allocation = np.random.dirichlet([alpha] * self.n_clients) * n_samples
allocation = allocation.astype(int)
start_index = allocation[0:self.client_id].sum()
end_index = 0
@@ -232,19 +236,21 @@

self.indices = indices


class UniformSampler(DistributedSamplerWrapper):
def __init__(self, dataset, num_replicas=None, rank=None, seed=0):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, seed=seed)
indices = list(range(len(self.dataset)))
self.indices = indices[self.rank:self.total_size:self.num_replicas]


def get_sampler(dataset, args):
sampler = None
if args.get_distributed():
method = args.get_sampler()
args.get_logger().info(
"Using {} sampler method, with args: {}".format(method, args.get_sampler_args()))

if method == "uniform":
sampler = UniformSampler(dataset, num_replicas=args.get_world_size(), rank=args.get_rank())
elif method == "q sampler":
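
The per-label allocation step used by DirichletSampler above can be reproduced in isolation; a small illustrative sketch with placeholder sizes shows how alpha controls the skew:

# Illustrative sketch, not part of this commit: the per-label allocation used
# by DirichletSampler. A probability vector is drawn from Dirichlet(alpha) and
# scaled to the number of samples; smaller alpha gives a more skewed
# (more non-iid) split across clients.
import numpy as np

alpha, n_clients, n_samples = 0.5, 4, 1000  # placeholder values
allocation = (np.random.dirichlet([alpha] * n_clients) * n_samples).astype(int)
print(allocation)  # per-client sample counts for one label (sum <= n_samples)
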
1 change: 0 additions & 1 deletion fltk/util/cluster/client.py
@@ -171,7 +171,6 @@ def start(self):
_thread_pool = ThreadPool(processes=2)
_thread_pool.apply(self._watchdog.start)
_thread_pool.apply(self._run)

_thread_pool.join()

def _stop(self):
30 changes: 27 additions & 3 deletions fltk/util/config/base_config.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from pathlib import Path

from dataclasses_json import config, dataclass_json

@@ -21,13 +22,37 @@ class ReproducibilityConfig:
arrival_seed: int


@dataclass_json
@dataclass(frozen=True)
class TensorboardConfig:
active: bool
record_dir: str

def prepare_log_dir(self, working_dir: Path = None):
"""
Function to create the logging directory used by TensorBoard. When running in a cluster, this function should not be
used, as a TensorBoard instance is started simultaneously with the Orchestrator.
@param working_dir: Current working directory, by default PWD is assumed at which the Python interpreter is
started.
@type working_dir: pathlib.Path
@return: None
@rtype: None
"""
dir_to_check = Path(self.record_dir)
if working_dir:
dir_to_check = working_dir.joinpath(dir_to_check)
if not dir_to_check.exists() and dir_to_check.parent.is_dir():
dir_to_check.mkdir()


@dataclass_json
@dataclass
class ExecutionConfig:
general_net: GeneralNetConfig = field(metadata=config(field_name="net"))
reproducibility: ReproducibilityConfig
tensorboard: TensorboardConfig

experiment_prefix: str = "experiment"
tensorboard_active: bool = True
cuda: bool = False


@@ -99,7 +124,6 @@ def set_test_data_loader_pickle_path(self, path, name='cifar10'):
def get_test_data_loader_pickle_path(self):
return self.test_data_loader_pickle_path[self.dataset_name]


def should_save_model(self, epoch_idx):
"""
Returns true/false models should be saved.
@@ -108,4 +132,4 @@ def should_save_model(self, epoch_idx):
:type epoch_idx: int
"""
return self.execution_config.general_net.save_model and (
epoch_idx == 1 or epoch_idx % self.execution_config.general_net.save_epoch_interval == 0)
epoch_idx == 1 or epoch_idx % self.execution_config.general_net.save_epoch_interval == 0)
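
A hypothetical end-to-end usage sketch of the configuration classes above, under the assumption (not confirmed in this diff) that BareConfig is itself a dataclasses_json dataclass like ExecutionConfig and TensorboardConfig:

# Hypothetical usage sketch, not part of this commit. Assumes BareConfig is
# decorated with @dataclass_json, so from_json is generated for it as well.
from pathlib import Path

from fltk.util.config.base_config import BareConfig

with open('configs/example_cloud_experiment.json') as f:
    config: BareConfig = BareConfig.from_json(f.read())

# Create the TensorBoard log directory locally; in a cluster this is skipped,
# since a TensorBoard instance is started alongside the Orchestrator.
config.execution_config.tensorboard.prepare_log_dir(working_dir=Path.cwd())

print(config.execution_config.tensorboard.record_dir)  # e.g. "logging"
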
2 changes: 1 addition & 1 deletion fltk/util/task/generator/arrival_generator.py
@@ -115,7 +115,7 @@ def stop(self) -> None:
self.logger.info("Received stopping signal")
self._alive = False

def run(self):
def run(self, **kwargs):
"""
Run function to generate arrivals during existence of the Orchestrator. WIP.
