Skip to content

Commit

Permalink
Update tests to reflect changes
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 6, 2022
1 parent b5e2518 commit 92c4a7d
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 6 deletions.
3 changes: 1 addition & 2 deletions configs/test/test_experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
"epoch_save_end_suffix": "end"
},
"reproducibility": {
"torch_seed": 2053695854357871005,
"arrival_seed": 123
"seeds": [44]
}
}
}
1 change: 1 addition & 0 deletions experiments/test/data_parallel.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: -1
batch_size: 1
test_batch_size: 1
cuda: True
Expand Down
1 change: 1 addition & 0 deletions experiments/test/data_parallel_non_default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: 2
batch_size: 1
test_batch_size: 1
cuda: True
Expand Down
1 change: 1 addition & 0 deletions experiments/test/federated.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: -1
batch_size: 1
test_batch_size: 1
cuda: false
Expand Down
1 change: 1 addition & 0 deletions experiments/test/federated_non_default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: 2
batch_size: 1
test_batch_size: 1
cuda: true
Expand Down
3 changes: 2 additions & 1 deletion experiments/test/parsing/data_parallel_parsed.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: -1
batch_size: 500
test_batch_size: 1000
cuda: False
Expand All @@ -11,5 +12,5 @@ dataset: fashion-mnist
max_epoch: 44
learning_rate: 1e3
learning_rate_decay: 0.0002
seed: 2053695854357871005
seed: 2746317213
loss: MSELoss
1 change: 1 addition & 0 deletions experiments/test/parsing/federated_parsed.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
replication: -1
batch_size: 1
test_batch_size: 1
cuda: True
Expand Down
2 changes: 1 addition & 1 deletion fltk/nets/util/reproducability.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def init_reproducibility(config: Optional[ExecutionConfig] = None, seed: Optiona


torch.manual_seed(torch_seed)
if config.cuda:
if seed or (config and config.cuda):
torch.cuda.manual_seed_all(torch_seed)
cuda_reproducible_backend(True)
np.random.seed(rand_seed)
Expand Down
4 changes: 2 additions & 2 deletions tests/nets/reproducibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class TestReproducibleNet(unittest.TestCase):
map(lambda x: [x], models)
)
def test_reproducible_initialization(self, network_class: Type[torch.nn.Module]): # pylint: disable=missing-function-docstring
init_reproducibility()
init_reproducibility(seed=42)
param_1: OrderedDict[str, torch.nn.Module] = network_class().state_dict()
init_reproducibility()
init_reproducibility(seed=42)
param_2: OrderedDict[str, torch.nn.Module] = network_class().state_dict()

for key, value in param_1.items():
Expand Down

0 comments on commit 92c4a7d

Please sign in to comment.