Commit de6966b: Resolve imports and styling issues

JMGaljaard committed Sep 16, 2022
1 parent 3596776

Showing 13 changed files with 28 additions and 23 deletions.
11 changes: 6 additions & 5 deletions fltk/core/distributed/client.py
@@ -48,7 +48,8 @@ def __init__(self, rank: int, task_id: str, world_size: int, config: Distributed

# Create model and dataset
self.loss_function = self.learning_params.get_loss_function()()
- self.dataset = get_dist_dataset(self.learning_params.dataset)(self.config, self.learning_params, self._id, self._world_size)
+ self.dataset = get_dist_dataset(self.learning_params.dataset)(self.config, self.learning_params, self._id,
+ self._world_size)
self.model = get_net(self.learning_params.model)()
self.device = self._init_device()

@@ -80,7 +81,7 @@ def prepare_learner(self, distributed: bool = False) -> None:

if self.config.execution_config.tensorboard.active and self._id == 0:
self.tb_writer = SummaryWriter(
- str(self.config.get_log_path(self._task_id, self._id, self.learning_params)))
+ str(self.config.get_log_path(self._task_id, self._id, self.learning_params)))

def stop_learner(self):
"""
Expand All @@ -91,7 +92,7 @@ def stop_learner(self):
self._logger.info(f"Tearing down Client {self._id}")
self.tb_writer.close()

- def _init_device(self, default_device: torch.device = torch.device('cpu')): # pylint: disable=no-member
+ def _init_device(self, default_device: torch.device = torch.device('cpu')):  # pylint: disable=no-member
"""
Initialize Torch to use available devices. Either prepares CUDA device, or disables CUDA during execution to run
with CPU only inference/training.
@@ -102,7 +103,7 @@ def _init_device(self, default_device: torch.device = torch.device('cpu')): # py
@rtype: None
"""
if self.config.cuda_enabled() and torch.cuda.is_available():
- return torch.device('cuda') # pylint: disable=no-member
+ return torch.device('cuda')  # pylint: disable=no-member
# Force usage of CPU
torch.cuda.is_available = lambda: False
return default_device
@@ -188,7 +189,7 @@ def test(self) -> Tuple[float, float, np.array, np.array, np.array]:
outputs = self.model(images)
# Currently, the FLTK framework assumes that a classification task is performed (hence max).
# Future work may add support for non-classification training.
- _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member
+ _, predicted = torch.max(outputs.data, 1)  # pylint: disable=no-member
total += labels.size(0)
correct += (predicted == labels).sum().item()

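For context: the _init_device logic shown above follows a common PyTorch pattern, preferring CUDA only when both the experiment configuration and the runtime allow it, and otherwise pinning execution to the CPU. A minimal standalone sketch of that pattern (the cuda_enabled flag stands in for FLTK's config.cuda_enabled() accessor):

import torch


def init_device(cuda_enabled: bool, default_device: torch.device = torch.device('cpu')) -> torch.device:
    """Return a CUDA device when requested and available, otherwise fall back to the default device."""
    if cuda_enabled and torch.cuda.is_available():
        return torch.device('cuda')
    # As in the snippet above: disable CUDA detection for the rest of the run and stay on CPU.
    torch.cuda.is_available = lambda: False
    return default_device


device = init_device(cuda_enabled=False)
batch = torch.zeros(4, 3, 32, 32, device=device)  # tensors are now allocated on the selected device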
1 change: 1 addition & 0 deletions fltk/core/distributed/extractor.py
@@ -10,6 +10,7 @@
from fltk.util.config import DistributedConfig


+ # noinspection PyUnresolvedReferences
def download_datasets(args: Namespace, config: DistributedConfig):
"""
Function to Download datasets to a system. This is currently meant to be run (using the extractor mode of FLTK) to
4 changes: 2 additions & 2 deletions fltk/nets/util/model.py
@@ -1,15 +1,15 @@
from __future__ import annotations

import logging
from collections import OrderedDict
from pathlib import Path
+ from typing import TYPE_CHECKING
from typing import Union, Type

import deprecate
import torch
from torch.utils.tensorboard import SummaryWriter

from fltk.util.results import EpochData
- from typing import TYPE_CHECKING

if TYPE_CHECKING:
from fltk.util.config import DistributedConfig
1 change: 1 addition & 0 deletions fltk/nets/util/reproducability.py
@@ -8,6 +8,7 @@
from fltk.util.config import DistLearnerConfig


+ # noinspection PyUnresolvedReferences
def cuda_reproducible_backend(cuda: bool) -> None:
"""
Function to set the CUDA backend to reproducible (i.e. deterministic) or to default configuration (per PyTorch
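cuda_reproducible_backend is only shown here by its signature and docstring; the standard way to toggle cuDNN between deterministic and default behaviour in PyTorch looks roughly like the sketch below. This is the general pattern, not necessarily FLTK's exact implementation:

import torch


def cuda_reproducible_backend(cuda: bool) -> None:
    """Toggle cuDNN between deterministic (reproducible) and default (faster) behaviour."""
    if cuda:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False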
7 changes: 4 additions & 3 deletions fltk/util/cluster/client.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
+ import time
from collections import defaultdict
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Tuple, Optional, OrderedDict, Union
+ from typing import TYPE_CHECKING
from uuid import UUID

import schedule
@@ -15,12 +17,10 @@

from fltk.util.cluster.conversion import Convert
from fltk.util.singleton import Singleton
- from fltk.util.config.experiment import SystemResources
from fltk.util.task.arrival_task import DistributedArrivalTask, ArrivalTask, FederatedArrivalTask

- from typing import TYPE_CHECKING

if TYPE_CHECKING:
+ from fltk.util.config.experiment_config import SystemResources
from fltk.util.config import DistributedConfig

@dataclass
@@ -188,6 +188,7 @@ def start(self):
"""
self._logger.info("Spinning up cluster manager...")
# Set debugging to WARNING only, as otherwise DEBUG statements will flood the logs.
+ # noinspection PyUnresolvedReferences
client.rest.logger.setLevel(logging.WARNING)
self.__alive = True
self.__thread_pool.apply_async(self._watchdog.start)
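Several import changes in this commit move type-only imports under a typing.TYPE_CHECKING guard, so the names are visible to static type checkers but never imported at runtime, which also sidesteps circular imports between the config and cluster modules. The pattern in isolation (build_cluster_manager is a hypothetical function used only for illustration):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, PyCharm), never at runtime.
    from fltk.util.config import DistributedConfig


def build_cluster_manager(config: DistributedConfig) -> None:
    # With the __future__ annotations import the annotation above stays unevaluated
    # at runtime, so the guarded import is sufficient.
    ...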
1 change: 1 addition & 0 deletions fltk/util/config/__init__.py
@@ -9,6 +9,7 @@
from fltk.util.config.learner_config import FedLearnerConfig, get_safe_loader, DistLearnerConfig
from fltk.util.config.experiment_config import ExperimentConfig, ExperimentParser


def retrieve_config_network_params(conf: FedLearnerConfig, nic=None, host=None):
if hasattr(conf, 'system'):
system_attr = getattr(conf, 'system')
1 change: 1 addition & 0 deletions fltk/util/config/definitions/loss.py
@@ -7,6 +7,7 @@
from typing import Dict, Type, Union

import torch
+ # noinspection PyProtectedMember
from torch.nn.modules.loss import _Loss


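The noinspection PyProtectedMember comment is needed because _Loss is the private common base class of the loss modules in torch.nn, so importing it trips PyCharm's protected-member inspection even though the import works. A small usage sketch (the get_loss helper is illustrative, not FLTK code):

from typing import Type

import torch
# noinspection PyProtectedMember
from torch.nn.modules.loss import _Loss


def get_loss(loss_cls: Type[_Loss]) -> _Loss:
    # Accept any concrete torch.nn loss class and instantiate it with default arguments.
    return loss_cls()


criterion = get_loss(torch.nn.CrossEntropyLoss)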
2 changes: 0 additions & 2 deletions fltk/util/config/distributed_config.py
@@ -158,8 +158,6 @@ def get_log_path(self, experiment_id: str, client_id: int, learn_params: DistLea
@type experiment_id: str
@param client_id: Rank of the client.
@type client_id: int
- @param network_name: Name of the network that is to be trained.
- @type network_name: str
@return: Path representation of the directory/path should be logged by the training process.
@rtype: Path
"""
1 change: 1 addition & 0 deletions fltk/util/config/experiment_config.py
@@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass, field
from pathlib import Path
+ # noinspection PyUnresolvedReferences
from typing import Optional, Union, Tuple, Dict, Any, MutableMapping, Type, OrderedDict, List, T

from dataclasses_json import dataclass_json, LetterCase, config
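The suppression above is mainly for T: CPython's typing module defines an internal TypeVar named T that can be imported but is not part of the documented public API, which is what the IDE inspection flags. An inspection-clean alternative would be to declare the type variable explicitly (shown only as an option, not as what the repository does):

from typing import TypeVar

T = TypeVar('T')  # explicit, public-API equivalent of the unexported typing.T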
7 changes: 4 additions & 3 deletions fltk/util/task/arrival_task.py
@@ -3,13 +3,14 @@
import random
import uuid
from dataclasses import field, dataclass
- from typing import OrderedDict, Optional, T, List # noinspection PyUnresolvedReferences
+ # noinspection PyUnresolvedReferences
+ from typing import OrderedDict, Optional, T, List
from uuid import UUID

from fltk.datasets.dataset import Dataset
from fltk.util.config.definitions import Nets
- from fltk.util.config.experiment import OptimizerConfig, HyperParameters, SystemResources, SystemParameters, \
- SamplerConfiguration, LearningParameters
+ from fltk.util.config.experiment_config import (OptimizerConfig, HyperParameters, SystemResources, SystemParameters,
+ SamplerConfiguration, LearningParameters)
from fltk.util.task.generator.arrival_generator import Arrival

MASTER_REPLICATION: int = 1 # Static master replication value, dictated by PytorchTrainingJobs
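Besides the module rename from fltk.util.config.experiment to fltk.util.config.experiment_config, the long import is rewrapped: a parenthesized import list replaces the backslash continuation, which is the line-breaking style PEP 8 recommends. Schematically, with names taken from the diff above:

# Backslash continuation (previous style):
from fltk.util.config.experiment_config import OptimizerConfig, HyperParameters, SystemResources, \
    SamplerConfiguration

# Parenthesized import list (style applied in this commit):
from fltk.util.config.experiment_config import (OptimizerConfig, HyperParameters, SystemResources,
                                                SamplerConfiguration)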
6 changes: 3 additions & 3 deletions fltk/util/task/generator/arrival_generator.py
@@ -13,11 +13,11 @@

import numpy as np

- from fltk.util.config.definitions.net import Nets
from fltk.datasets.dataset import Dataset
+ from fltk.util.config.definitions.net import Nets
+ from fltk.util.config.experiment_config import (HyperParameters, SystemParameters, LearningParameters, JobDescription,
+ ExperimentParser)
from fltk.util.singleton import Singleton
- from fltk.util.config.experiment import HyperParameters, SystemParameters, LearningParameters, JobDescription, \
- ExperimentParser
from fltk.util.task.train_task import TrainTask


4 changes: 2 additions & 2 deletions fltk/util/task/train_task.py
@@ -4,8 +4,8 @@
from dataclasses_json import config

from fltk.util.config.definitions import ExperimentType
- from fltk.util.config.experiment import NetworkConfiguration, SystemParameters, HyperParameters, LearningParameters, \
- JobClassParameter, Priority
+ from fltk.util.config.experiment_config import (NetworkConfiguration, SystemParameters, HyperParameters, Priority,
+ LearningParameters, JobClassParameter)


@dataclass(order=True)
5 changes: 2 additions & 3 deletions tests/util/config/test_learning_config.py
@@ -6,12 +6,11 @@
import yaml

from fltk.core.distributed.orchestrator import render_template
- from fltk.util.config import DistLearnerConfig, FedLearnerConfig, get_safe_loader
+ from fltk.util.config import DistLearnerConfig, FedLearnerConfig, get_safe_loader, ExperimentParser
from fltk.util.config.definitions import Optimizations
from fltk.util.task import FederatedArrivalTask, DistributedArrivalTask
- from fltk.util.config.experiment import ExperimentParser
- from fltk.util.task.train_task import TrainTask
from fltk.util.task.generator.arrival_generator import Arrival
+ from fltk.util.task.train_task import TrainTask

TEST_FED_CONF = './configs/test/fed_non_default.json'
TEST_DIST_CONF = './configs/test/dist_non_default.json'
