Skip to content

Commit

Permalink
Utilize unified seed configuration in reproducability
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 6, 2022
1 parent 81b346e commit caba030
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions fltk/nets/util/reproducability.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def cuda_reproducible_backend(cuda: bool) -> None:
torch.backends.cudnn.deterministic = False


def init_reproducibility(config: Optional[ExecutionConfig] = None, test_seed: int = 42) -> None:
def init_reproducibility(config: Optional[ExecutionConfig] = None, seed: Optional[int] = None) -> None:
"""
Function to pre-set all seeds for libraries used during training. Allows for re-producible network initialization,
and non-deterministic number generation. Allows to prevent 'lucky' draws in network initialization.
Expand All @@ -34,21 +34,17 @@ def init_reproducibility(config: Optional[ExecutionConfig] = None, test_seed: in
@return: None
@rtype: None
"""
if config:
random_seed = config.reproducibility.arrival_seed
torch_seed = config.reproducibility.torch_seed
cuda = config.cuda
else:
random_seed = test_seed
torch_seed = test_seed
cuda = torch.cuda.is_available()
torch_seed, rand_seed = seed, seed
if not seed:
torch_seed, rand_seed = config.reproducibility.seeds[0], config.reproducibility.seeds[0]


torch.manual_seed(torch_seed)
if cuda:
if config.cuda:
torch.cuda.manual_seed_all(torch_seed)
cuda_reproducible_backend(True)
np.random.seed(random_seed)
os.environ['PYTHONHASHSEED'] = str(random_seed)
np.random.seed(rand_seed)
os.environ['PYTHONHASHSEED'] = str(rand_seed)



Expand Down

0 comments on commit caba030

Please sign in to comment.