From 92c4a7d3a1a83c5910a108bc5433694a61a4988c Mon Sep 17 00:00:00 2001 From: "Jeroen M. Galjaard" Date: Tue, 6 Sep 2022 15:25:06 +0200 Subject: [PATCH] Update tests to reflect changes --- configs/test/test_experiment.json | 3 +-- experiments/test/data_parallel.yaml | 1 + experiments/test/data_parallel_non_default.yaml | 1 + experiments/test/federated.yaml | 1 + experiments/test/federated_non_default.yaml | 1 + experiments/test/parsing/data_parallel_parsed.yaml | 3 ++- experiments/test/parsing/federated_parsed.yaml | 1 + fltk/nets/util/reproducability.py | 2 +- tests/nets/reproducibility_test.py | 4 ++-- 9 files changed, 11 insertions(+), 6 deletions(-) diff --git a/configs/test/test_experiment.json b/configs/test/test_experiment.json index b433527d..09802545 100644 --- a/configs/test/test_experiment.json +++ b/configs/test/test_experiment.json @@ -27,8 +27,7 @@ "epoch_save_end_suffix": "end" }, "reproducibility": { - "torch_seed": 2053695854357871005, - "arrival_seed": 123 + "seeds": [44] } } } \ No newline at end of file diff --git a/experiments/test/data_parallel.yaml b/experiments/test/data_parallel.yaml index 5d6b6f81..6d7185ee 100644 --- a/experiments/test/data_parallel.yaml +++ b/experiments/test/data_parallel.yaml @@ -1,3 +1,4 @@ +replication: -1 batch_size: 1 test_batch_size: 1 cuda: True diff --git a/experiments/test/data_parallel_non_default.yaml b/experiments/test/data_parallel_non_default.yaml index 2e51a562..19f8c7e3 100644 --- a/experiments/test/data_parallel_non_default.yaml +++ b/experiments/test/data_parallel_non_default.yaml @@ -1,3 +1,4 @@ +replication: 2 batch_size: 1 test_batch_size: 1 cuda: True diff --git a/experiments/test/federated.yaml b/experiments/test/federated.yaml index 21703a81..b51efc22 100644 --- a/experiments/test/federated.yaml +++ b/experiments/test/federated.yaml @@ -1,3 +1,4 @@ +replication: -1 batch_size: 1 test_batch_size: 1 cuda: false diff --git a/experiments/test/federated_non_default.yaml b/experiments/test/federated_non_default.yaml index 755de821..ca54dbb7 100644 --- a/experiments/test/federated_non_default.yaml +++ b/experiments/test/federated_non_default.yaml @@ -1,3 +1,4 @@ +replication: 2 batch_size: 1 test_batch_size: 1 cuda: true diff --git a/experiments/test/parsing/data_parallel_parsed.yaml b/experiments/test/parsing/data_parallel_parsed.yaml index 509154aa..ebaa3d8c 100644 --- a/experiments/test/parsing/data_parallel_parsed.yaml +++ b/experiments/test/parsing/data_parallel_parsed.yaml @@ -1,3 +1,4 @@ +replication: -1 batch_size: 500 test_batch_size: 1000 cuda: False @@ -11,5 +12,5 @@ dataset: fashion-mnist max_epoch: 44 learning_rate: 1e3 learning_rate_decay: 0.0002 -seed: 2053695854357871005 +seed: 2746317213 loss: MSELoss diff --git a/experiments/test/parsing/federated_parsed.yaml b/experiments/test/parsing/federated_parsed.yaml index 78130415..71220f07 100644 --- a/experiments/test/parsing/federated_parsed.yaml +++ b/experiments/test/parsing/federated_parsed.yaml @@ -1,3 +1,4 @@ +replication: -1 batch_size: 1 test_batch_size: 1 cuda: True diff --git a/fltk/nets/util/reproducability.py b/fltk/nets/util/reproducability.py index a3ef9261..50ac744d 100644 --- a/fltk/nets/util/reproducability.py +++ b/fltk/nets/util/reproducability.py @@ -40,7 +40,7 @@ def init_reproducibility(config: Optional[ExecutionConfig] = None, seed: Optiona torch.manual_seed(torch_seed) - if config.cuda: + if seed or (config and config.cuda): torch.cuda.manual_seed_all(torch_seed) cuda_reproducible_backend(True) np.random.seed(rand_seed) diff --git a/tests/nets/reproducibility_test.py b/tests/nets/reproducibility_test.py index 7311814f..f5480e4c 100644 --- a/tests/nets/reproducibility_test.py +++ b/tests/nets/reproducibility_test.py @@ -30,9 +30,9 @@ class TestReproducibleNet(unittest.TestCase): map(lambda x: [x], models) ) def test_reproducible_initialization(self, network_class: Type[torch.nn.Module]): # pylint: disable=missing-function-docstring - init_reproducibility() + init_reproducibility(seed=42) param_1: OrderedDict[str, torch.nn.Module] = network_class().state_dict() - init_reproducibility() + init_reproducibility(seed=42) param_2: OrderedDict[str, torch.nn.Module] = network_class().state_dict() for key, value in param_1.items():