diff --git a/fltk/nets/util/reproducability.py b/fltk/nets/util/reproducability.py index 46b9e118..a3ef9261 100644 --- a/fltk/nets/util/reproducability.py +++ b/fltk/nets/util/reproducability.py @@ -25,7 +25,7 @@ def cuda_reproducible_backend(cuda: bool) -> None: torch.backends.cudnn.deterministic = False -def init_reproducibility(config: Optional[ExecutionConfig] = None, test_seed: int = 42) -> None: +def init_reproducibility(config: Optional[ExecutionConfig] = None, seed: Optional[int] = None) -> None: """ Function to pre-set all seeds for libraries used during training. Allows for re-producible network initialization, and non-deterministic number generation. Allows to prevent 'lucky' draws in network initialization. @@ -34,21 +34,17 @@ def init_reproducibility(config: Optional[ExecutionConfig] = None, test_seed: in @return: None @rtype: None """ - if config: - random_seed = config.reproducibility.arrival_seed - torch_seed = config.reproducibility.torch_seed - cuda = config.cuda - else: - random_seed = test_seed - torch_seed = test_seed - cuda = torch.cuda.is_available() + torch_seed, rand_seed = seed, seed + if not seed: + torch_seed, rand_seed = config.reproducibility.seeds[0], config.reproducibility.seeds[0] + torch.manual_seed(torch_seed) - if cuda: + if config.cuda: torch.cuda.manual_seed_all(torch_seed) cuda_reproducible_backend(True) - np.random.seed(random_seed) - os.environ['PYTHONHASHSEED'] = str(random_seed) + np.random.seed(rand_seed) + os.environ['PYTHONHASHSEED'] = str(rand_seed)