Skip to content

Commit

Permalink
Add replication to config and change logging dir
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 6, 2022
1 parent e90d136 commit c193414
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
21 changes: 16 additions & 5 deletions fltk/util/config/distributed_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, List

from dataclasses_json import config, dataclass_json

from fltk.util.config.definitions import OrchestratorType

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from fltk.util.config import DistLearningConfig


@dataclass_json
@dataclass
class GeneralNetConfig:
Expand All @@ -20,8 +28,10 @@ class GeneralNetConfig:
@dataclass_json
@dataclass(frozen=True)
class ReproducibilityConfig:
torch_seed: int
arrival_seed: int
"""
Dataclass object to hold experiment configuration settings related to reproducibility of experiments.
"""
seeds: List[int]


@dataclass_json
Expand Down Expand Up @@ -139,7 +149,7 @@ def get_log_dir(self):
"""
return self.execution_config.log_path

def get_log_path(self, experiment_id: str, client_id: int, network_name: str) -> Path:
def get_log_path(self, experiment_id: str, client_id: int, learn_params: DistLearningConfig) -> Path:
"""
Function to get the logging path that corresponds to a specific experiment, client and network that has been
deployed as learning task.
Expand All @@ -153,7 +163,8 @@ def get_log_path(self, experiment_id: str, client_id: int, network_name: str) ->
@rtype: Path
"""
base_log = Path(self.execution_config.tensorboard.record_dir)
experiment_dir = Path(f"{self.execution_config.experiment_prefix}_{client_id}_{network_name}_{experiment_id}")
model, dataset, replication = learn_params.model, learn_params.dataset, learn_params.replication
experiment_dir = Path(f"{replication}/{self.execution_config.experiment_prefix}_{experiment_id}/{client_id}/{model}_{dataset}")
return base_log.joinpath(experiment_dir)

def get_data_path(self) -> Path:
Expand Down
1 change: 1 addition & 0 deletions fltk/util/config/learning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_safe_loader() -> yaml.SafeLoader:
@dataclass_json
@dataclass
class LearningConfig:
replication: int = field(metadata=dict(required=False, missing=-1))
batch_size: int = field(metadata=dict(required=False, missing=128))
test_batch_size: int = field(metadata=dict(required=False, missing=128))
cuda: bool = field(metadata=dict(required=False, missing=False))
Expand Down

0 comments on commit c193414

Please sign in to comment.