Skip to content

Commit

Permalink
Make arrival task hashable (set insertion)
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 18, 2022
1 parent ca14ccf commit 1007922
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions fltk/util/task/arrival_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import uuid
from dataclasses import field, dataclass
# noinspection PyUnresolvedReferences
from typing import OrderedDict, Optional, T, List
from typing import Optional, T, List, OrderedDict
from uuid import UUID

from frozendict import FrozenOrderedDict

from fltk.datasets.dataset import Dataset
from fltk.util.config.definitions import Nets
from fltk.util.config.experiment_config import (OptimizerConfig, HyperParameters, SystemResources, SystemParameters,
Expand All @@ -16,7 +18,7 @@
MASTER_REPLICATION: int = 1 # Static master replication value, dictated by PytorchTrainingJobs


@dataclass
@dataclass(frozen=True)
class ArrivalTask(abc.ABC):
"""
DataClass representation of an ArrivalTask, representing all the information needed to spawn a new learning task.
Expand All @@ -27,7 +29,7 @@ class ArrivalTask(abc.ABC):
loss_function: str = field(compare=False)
seed: int = field(compare=False)
replication: int = field(compare=False)
type_map: Optional[OrderedDict[str, int]]
type_map: "Optional[FrozenOrderedDict[str, int]]"
system_parameters: SystemParameters = field(compare=False)
hyper_parameters: HyperParameters = field(compare=False)
learning_parameters: LearningParameters = field(compare=False)
Expand Down Expand Up @@ -175,7 +177,7 @@ def get_net_param(self, parameter):
return getattr(self, parameter)


@dataclass(order=True)
@dataclass(order=True, frozen=True)
class DistributedArrivalTask(ArrivalTask):
"""
Object to contain configuration of training task. It describes the following properties;
Expand Down Expand Up @@ -211,7 +213,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T:
loss_function=arrival.task.network_configuration.loss_function,
seed=random.randint(0, 2**32 - 2),
replication=replication,
type_map=collections.OrderedDict({
type_map=FrozenOrderedDict({
'Master': MASTER_REPLICATION,
'Worker': arrival.task.system_parameters.data_parallelism - MASTER_REPLICATION
}),
Expand All @@ -221,7 +223,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T:
return task


@dataclass(order=True)
@dataclass(order=True, frozen=True)
class FederatedArrivalTask(ArrivalTask):
"""
Task describing configuration objects for running FederatedLearning experiments on K8s.
Expand All @@ -236,7 +238,7 @@ def build(arrival: Arrival, u_id: uuid.UUID, replication: int) -> T:
loss_function=arrival.task.network_configuration.loss_function,
seed=arrival.task.seed,
replication=replication,
type_map=collections.OrderedDict({
type_map=FrozenOrderedDict({
'Master': MASTER_REPLICATION,
'Worker': arrival.task.system_parameters.data_parallelism
}),
Expand Down

0 comments on commit 1007922

Please sign in to comment.