diff --git a/fltk/core/distributed/client.py b/fltk/core/distributed/client.py
index 8212d015..b031638d 100644
--- a/fltk/core/distributed/client.py
+++ b/fltk/core/distributed/client.py
@@ -13,6 +13,7 @@
 from fltk.nets import get_net
 from fltk.nets.util import calculate_class_precision, calculate_class_recall, save_model, load_model_from_file
 from fltk.schedulers import MinCapableStepLR, LearningScheduler
+from fltk.strategy import get_optimizer
 from fltk.util.config import DistributedConfig, DistLearningConfig
 from fltk.util.results import EpochData
 
@@ -68,8 +69,8 @@ def prepare_learner(self, distributed: bool = False) -> None:
             # Wrap the model to use pytorch DistributedDataParallel wrapper for all reduce.
             self.model = torch.nn.parallel.DistributedDataParallel(self.model)
 
-        # Currently, it is assumed to use an SGD optimizer. **kwargs need to be used to launch this properly
-        optim_type: Type[torch.optim.Optimizer] = self.learning_params.get_optimizer()
+        # Currently, it is assumed to use an SGD optimizer, using non-federated optimizer types.
+        optim_type = get_optimizer(self.learning_params.optimizer, federated=False)
         self.optimizer = optim_type(self.model.parameters(), **self.learning_params.optimizer_args)
         self.scheduler = MinCapableStepLR(self.optimizer,
                                           self.learning_params.scheduler_step_size,
diff --git a/fltk/strategy/optimization/__init__.py b/fltk/strategy/optimization/__init__.py
index a394f55f..34720f08 100644
--- a/fltk/strategy/optimization/__init__.py
+++ b/fltk/strategy/optimization/__init__.py
@@ -7,7 +7,7 @@
 from .fed_nova import FedNova
 
 
-def get_optimizer(name: Optimizations) -> Type[torch.optim.Optimizer]:
+def get_optimizer(name: Optimizations, federated: bool = True) -> Type[torch.optim.Optimizer]:
     """
     Helper function to get specific Optimization class references.
     @param name: Optimizer class reference.
@@ -20,7 +20,10 @@ def get_optimizer(name: Optimizations) -> Type[torch.optim.Optimizer]:
         Optimizations.adam: torch.optim.Adam,
         Optimizations.adam_w: torch.optim.AdamW,
         Optimizations.sgd: torch.optim.SGD,
+    }
+    if federated:
+        optimizers.update({
         Optimizations.fedprox: FedProx,
         Optimizations.fednova: FedNova
-    }
+        })
     return optimizers[name]
diff --git a/fltk/util/config/__init__.py b/fltk/util/config/__init__.py
index 7690692f..139a1d5a 100644
--- a/fltk/util/config/__init__.py
+++ b/fltk/util/config/__init__.py
@@ -6,7 +6,7 @@
 import logging
 
 from fltk.util.config.distributed_config import DistributedConfig
-from fltk.util.config.config import FedLearningConfig, get_safe_loader, DistLearningConfig
+from fltk.util.config.learning_config import FedLearningConfig, get_safe_loader, DistLearningConfig
 
 
 def retrieve_config_network_params(conf: FedLearningConfig, nic=None, host=None):
diff --git a/fltk/util/config/config.py b/fltk/util/config/learning_config.py
similarity index 89%
rename from fltk/util/config/config.py
rename to fltk/util/config/learning_config.py
index c84e40bd..7d644333 100644
--- a/fltk/util/config/config.py
+++ b/fltk/util/config/learning_config.py
@@ -49,20 +49,26 @@ def get_safe_loader() -> yaml.SafeLoader:
     return safe_loader
 
 
+# fixme: With python 3.10, this can be done with the dataclass kw_only kwarg.
 @dataclass_json
 @dataclass
-class FedLearningConfig:
-    batch_size: int = 1
-    test_batch_size: int = 1000
+class LearningConfig:
+    batch_size: int = field(metadata=dict(required=False, missing=128))
+    test_batch_size: int = field(metadata=dict(required=False, missing=128))
+    cuda: bool = field(metadata=dict(required=False, missing=False))
+    scheduler_step_size: int = field(metadata=dict(required=False, missing=50))
+    scheduler_gamma: float = field(metadata=dict(required=False, missing=0.5))
+
+
+@dataclass_json
+@dataclass
+class FedLearningConfig(LearningConfig):
     rounds: int = 2
     epochs: int = 1
     lr: float = 0.01
     momentum: float = 0.1
-    cuda: bool = False
     shuffle: bool = False
     log_interval: int = 10
-    scheduler_step_size: int = 50
-    scheduler_gamma: float = 0.5
     min_lr: float = 1e-10
     rng_seed = 0
 
@@ -169,8 +175,8 @@ def from_yaml(path: Path):
 
 
 @dataclass_json
-@dataclass(frozen=True)
-class DistLearningConfig: # pylint: disable=too-many-instance-attributes
+@dataclass
+class DistLearningConfig(LearningConfig): # pylint: disable=too-many-instance-attributes
     """
     Class encapsulating LearningParameters, for now used under DistributedLearning.
     """
@@ -182,13 +188,10 @@ class DistLearningConfig: # pylint: disable=too-many-instance-attributes
     learning_rate: float
     learning_decay: float
     loss: str
-    optimizer: str
+    optimizer: Optimizations
     optimizer_args: Dict[str, Any]
-    scheduler_step_size: int
-    scheduler_gamma: float
-    min_lr: float
-    cuda: bool
+    min_lr: float
 
     seed: int
 
     @staticmethod
@@ -216,11 +219,3 @@ def get_loss(self) -> Type:
         """
         return self.__safe_get(_available_loss, self.loss)
 
-    def get_optimizer(self) -> Type[torch.optim.Optimizer]:
-        """
-        Function to obtain the loss function Type that was given via commandline to be used during the training
-        execution.
-        @return: Type corresponding to the Optimizer to be used during training.
-        @rtype: Type[torch.optim.Optimizer]
-        """
-        return self.__safe_get(_available_optimizer, self.optimizer)
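
The sketch below is not part of the patch; it is a minimal, standalone illustration of the new get_optimizer(name, federated=...) lookup, so the effect of the federated flag is visible in isolation. The Optimizations enum here is a stand-in (its members and values are assumed), and FedProx/FedNova are placeholders for the implementations imported in fltk/strategy/optimization/__init__.py; only torch is needed to run it.

from enum import Enum
from typing import Dict, Type

import torch


class Optimizations(Enum):
    # Stand-in for fltk's Optimizations enum; member values are illustrative only.
    sgd = 'SGD'
    adam = 'Adam'
    fedprox = 'FedProx'
    fednova = 'FedNova'


class FedProx(torch.optim.SGD):
    """Placeholder for the real FedProx optimizer."""


class FedNova(torch.optim.SGD):
    """Placeholder for the real FedNova optimizer."""


def get_optimizer(name: Optimizations, federated: bool = True) -> Type[torch.optim.Optimizer]:
    # Plain torch optimizers are always resolvable.
    optimizers: Dict[Optimizations, Type[torch.optim.Optimizer]] = {
        Optimizations.sgd: torch.optim.SGD,
        Optimizations.adam: torch.optim.Adam,
    }
    if federated:
        # Federated-only optimizers are registered only for federated learners, so a
        # distributed (non-federated) client cannot select them by accident.
        optimizers.update({
            Optimizations.fedprox: FedProx,
            Optimizations.fednova: FedNova,
        })
    return optimizers[name]


model = torch.nn.Linear(4, 2)
# Distributed client path, mirroring prepare_learner: federated=False.
optim_type = get_optimizer(Optimizations.sgd, federated=False)
optimizer = optim_type(model.parameters(), lr=0.01, momentum=0.9)
# get_optimizer(Optimizations.fedprox, federated=False) raises KeyError by design.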