
Update config file to learning_config file
JMGaljaard committed May 11, 2022
Commit dabcf8e (1 parent: a8640ab)
Showing 4 changed files with 25 additions and 26 deletions.
fltk/core/distributed/client.py (5 changes: 3 additions & 2 deletions)
@@ -13,6 +13,7 @@
  from fltk.nets import get_net
  from fltk.nets.util import calculate_class_precision, calculate_class_recall, save_model, load_model_from_file
  from fltk.schedulers import MinCapableStepLR, LearningScheduler
+ from fltk.strategy import get_optimizer
  from fltk.util.config import DistributedConfig, DistLearningConfig
  from fltk.util.results import EpochData

@@ -68,8 +69,8 @@ def prepare_learner(self, distributed: bool = False) -> None:
          # Wrap the model to use pytorch DistributedDataParallel wrapper for all reduce.
          self.model = torch.nn.parallel.DistributedDataParallel(self.model)

-         # Currently, it is assumed to use an SGD optimizer. **kwargs need to be used to launch this properly
-         optim_type: Type[torch.optim.Optimizer] = self.learning_params.get_optimizer()
+         # Currently, it is assumed to use an SGD optimizer, using non-federated optimizer types.
+         optim_type = get_optimizer(self.learning_params.optimizer, federated=False)
          self.optimizer = optim_type(self.model.parameters(), **self.learning_params.optimizer_args)
          self.scheduler = MinCapableStepLR(self.optimizer,
                                            self.learning_params.scheduler_step_size,
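
For illustration, a minimal standalone sketch of the two-step pattern prepare_learner now follows: resolve an optimizer type first, then instantiate it with the configured keyword arguments. Here torch.optim.SGD stands in for whatever get_optimizer(self.learning_params.optimizer, federated=False) returns, and the kwargs dict stands in for learning_params.optimizer_args; neither value is taken from the commit itself.

    import torch

    # Stand-ins for the repository's wrapped model and learning parameters.
    model = torch.nn.Linear(10, 2)
    optimizer_args = {'lr': 0.01, 'momentum': 0.1}   # e.g. learning_params.optimizer_args

    # Resolve the optimizer *type*, then construct it from the configured kwargs,
    # mirroring optim_type(self.model.parameters(), **self.learning_params.optimizer_args).
    optim_type = torch.optim.SGD                     # e.g. get_optimizer(..., federated=False)
    optimizer = optim_type(model.parameters(), **optimizer_args)
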
fltk/strategy/optimization/__init__.py (7 changes: 5 additions & 2 deletions)
@@ -7,7 +7,7 @@
  from .fed_nova import FedNova


- def get_optimizer(name: Optimizations) -> Type[torch.optim.Optimizer]:
+ def get_optimizer(name: Optimizations, federated: bool = True) -> Type[torch.optim.Optimizer]:
      """
      Helper function to get specific Optimization class references.
      @param name: Optimizer class reference.
@@ -20,7 +20,10 @@ def get_optimizer(name: Optimizations) -> Type[torch.optim.Optimizer]:
          Optimizations.adam: torch.optim.Adam,
          Optimizations.adam_w: torch.optim.AdamW,
          Optimizations.sgd: torch.optim.SGD,
+     }
+     if federated:
+         optimizers.update({
          Optimizations.fedprox: FedProx,
          Optimizations.fednova: FedNova
-     }
+         })
      return optimizers[name]
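
As a self-contained sketch of the lookup pattern above: federated-only entries are merged into the registry only when federated=True, so a distributed (non-federated) caller can resolve only the plain torch optimizers. The Optimizations enum and FedProx class below are minimal stand-ins for this sketch, not the repository's own definitions.

    from enum import Enum
    from typing import Type

    import torch


    class Optimizations(Enum):            # stand-in for fltk's Optimizations enum
        sgd = 'SGD'
        fedprox = 'FedProx'


    class FedProx(torch.optim.SGD):       # placeholder for the real FedProx optimizer
        pass


    def get_optimizer(name: Optimizations, federated: bool = True) -> Type[torch.optim.Optimizer]:
        optimizers = {
            Optimizations.sgd: torch.optim.SGD,
        }
        if federated:
            optimizers.update({
                Optimizations.fedprox: FedProx,
            })
        return optimizers[name]


    assert get_optimizer(Optimizations.sgd, federated=False) is torch.optim.SGD
    assert get_optimizer(Optimizations.fedprox) is FedProx

In this sketch, asking for a federated optimizer while federated=False raises a KeyError, matching the intent that distributed runs stick to non-federated optimizer types.
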
fltk/util/config/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@
  import logging

  from fltk.util.config.distributed_config import DistributedConfig
- from fltk.util.config.config import FedLearningConfig, get_safe_loader, DistLearningConfig
+ from fltk.util.config.learning_config import FedLearningConfig, get_safe_loader, DistLearningConfig


  def retrieve_config_network_params(conf: FedLearningConfig, nic=None, host=None):
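
Since the package __init__ above re-exports the same names from the renamed module, package-level imports are unaffected by this commit; only direct imports of the old module path need updating. A small illustration, assuming the fltk package is installed:

    # Package-level imports keep working as before the rename.
    from fltk.util.config import FedLearningConfig, DistLearningConfig, get_safe_loader

    # Direct module imports must now target the renamed file.
    from fltk.util.config.learning_config import DistLearningConfig   # previously fltk.util.config.config
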
fltk/util/config/config.py → fltk/util/config/learning_config.py (37 changes: 16 additions & 21 deletions)
@@ -49,20 +49,26 @@ def get_safe_loader() -> yaml.SafeLoader:
      return safe_loader


+ # fixme: With python 3.10, this can be done with the dataclass kw_only kwarg.
  @dataclass_json
  @dataclass
- class FedLearningConfig:
-     batch_size: int = 1
-     test_batch_size: int = 1000
+ class LearningConfig:
+     batch_size: int = field(metadata=dict(required=False, missing=128))
+     test_batch_size: int = field(metadata=dict(required=False, missing=128))
+     cuda: bool = field(metadata=dict(required=False, missing=False))
+     scheduler_step_size: int = field(metadata=dict(required=False, missing=50))
+     scheduler_gamma: float = field(metadata=dict(required=False, missing=0.5))
+
+
+ @dataclass_json
+ @dataclass
+ class FedLearningConfig(LearningConfig):
      rounds: int = 2
      epochs: int = 1
      lr: float = 0.01
      momentum: float = 0.1
-     cuda: bool = False
      shuffle: bool = False
      log_interval: int = 10
-     scheduler_step_size: int = 50
-     scheduler_gamma: float = 0.5
      min_lr: float = 1e-10
      rng_seed = 0

@@ -169,8 +175,8 @@ def from_yaml(path: Path):


  @dataclass_json
- @dataclass(frozen=True)
- class DistLearningConfig: # pylint: disable=too-many-instance-attributes
+ @dataclass
+ class DistLearningConfig(LearningConfig): # pylint: disable=too-many-instance-attributes
      """
      Class encapsulating LearningParameters, for now used under DistributedLearning.
      """
@@ -182,13 +188,10 @@ class DistLearningConfig: # pylint: disable=too-many-instance-attributes
      learning_rate: float
      learning_decay: float
      loss: str
-     optimizer: str
+     optimizer: Optimizations
      optimizer_args: Dict[str, Any]
-     scheduler_step_size: int
-     scheduler_gamma: float
-     min_lr: float

-     cuda: bool
+     min_lr: float
      seed: int

      @staticmethod
@@ -216,11 +219,3 @@ def get_loss(self) -> Type:
          """
          return self.__safe_get(_available_loss, self.loss)

-     def get_optimizer(self) -> Type[torch.optim.Optimizer]:
-         """
-         Function to obtain the loss function Type that was given via commandline to be used during the training
-         execution.
-         @return: Type corresponding to the Optimizer to be used during training.
-         @rtype: Type[torch.optim.Optimizer]
-         """
-         return self.__safe_get(_available_optimizer, self.optimizer)
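
The fixme note in the new LearningConfig base class refers to a simpler option once Python 3.10 can be required: with dataclass(kw_only=True), the shared defaults could be declared directly instead of via field(metadata=...), because keyword-only fields lift the rule that non-default fields may not follow defaulted ones in a subclass. A hedged sketch of that alternative, using hypothetical class names (LearningConfigSketch, DistLearningConfigSketch) and only fields that appear in the diff:

    from dataclasses import dataclass


    @dataclass(kw_only=True)              # requires Python 3.10+
    class LearningConfigSketch:
        batch_size: int = 128
        test_batch_size: int = 128
        cuda: bool = False
        scheduler_step_size: int = 50
        scheduler_gamma: float = 0.5


    @dataclass(kw_only=True)
    class DistLearningConfigSketch(LearningConfigSketch):
        # Required fields may follow the inherited defaults because all fields are keyword-only.
        learning_rate: float
        loss: str


    config = DistLearningConfigSketch(learning_rate=0.01, loss='CrossEntropyLoss')
    assert config.batch_size == 128       # inherited default fills in the omitted key
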
