From 10079225bebfcd6685b7a0dc6fc63d144660581f Mon Sep 17 00:00:00 2001 From: JMGaljaard Date: Thu, 15 Sep 2022 21:24:04 +0200 Subject: [PATCH] Make arrival task hashable (set insertion) --- fltk/util/task/arrival_task.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/fltk/util/task/arrival_task.py b/fltk/util/task/arrival_task.py index cce4f79b..d23041a5 100644 --- a/fltk/util/task/arrival_task.py +++ b/fltk/util/task/arrival_task.py @@ -4,9 +4,11 @@ import uuid from dataclasses import field, dataclass # noinspection PyUnresolvedReferences -from typing import OrderedDict, Optional, T, List +from typing import Optional, T, List, OrderedDict from uuid import UUID +from frozendict import FrozenOrderedDict + from fltk.datasets.dataset import Dataset from fltk.util.config.definitions import Nets from fltk.util.config.experiment_config import (OptimizerConfig, HyperParameters, SystemResources, SystemParameters, @@ -16,7 +18,7 @@ MASTER_REPLICATION: int = 1 # Static master replication value, dictated by PytorchTrainingJobs -@dataclass +@dataclass(frozen=True) class ArrivalTask(abc.ABC): """ DataClass representation of an ArrivalTask, representing all the information needed to spawn a new learning task. @@ -27,7 +29,7 @@ class ArrivalTask(abc.ABC): loss_function: str = field(compare=False) seed: int = field(compare=False) replication: int = field(compare=False) - type_map: Optional[OrderedDict[str, int]] + type_map: "Optional[FrozenOrderedDict[str, int]]" system_parameters: SystemParameters = field(compare=False) hyper_parameters: HyperParameters = field(compare=False) learning_parameters: LearningParameters = field(compare=False) @@ -175,7 +177,7 @@ def get_net_param(self, parameter): return getattr(self, parameter) -@dataclass(order=True) +@dataclass(order=True, frozen=True) class DistributedArrivalTask(ArrivalTask): """ Object to contain configuration of training task. It describes the following properties; @@ -211,7 +213,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T: loss_function=arrival.task.network_configuration.loss_function, seed=random.randint(0, 2**32 - 2), replication=replication, - type_map=collections.OrderedDict({ + type_map=FrozenOrderedDict({ 'Master': MASTER_REPLICATION, 'Worker': arrival.task.system_parameters.data_parallelism - MASTER_REPLICATION }), @@ -221,7 +223,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T: return task -@dataclass(order=True) +@dataclass(order=True, frozen=True) class FederatedArrivalTask(ArrivalTask): """ Task describing configuration objects for running FederatedLearning experiments on K8s. @@ -236,7 +238,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T: loss_function=arrival.task.network_configuration.loss_function, seed=arrival.task.seed, replication=replication, - type_map=collections.OrderedDict({ + type_map=FrozenOrderedDict({ 'Master': MASTER_REPLICATION, 'Worker': arrival.task.system_parameters.data_parallelism }),