Fix nans with nan_to_num
reverendbedford committed Sep 14, 2023
1 parent 686c2b1 commit cfa2647
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pyoptmat/optimize.py
@@ -189,7 +189,7 @@ class StatisticalModel(PyroModule):
         entry i represents the noise in test type i
     """

-    def __init__(self, maker, names, locs, scales, eps):
+    def __init__(self, maker, names, locs, scales, eps, nan_num = False):
         super().__init__()

         self.maker = maker
@@ -203,6 +203,8 @@ def __init__(self, maker, names, locs, scales, eps):

         self.type_noise = self.eps.dim() > 0

+        self.nan_num = nan_num
+
     def get_params(self):
         """
         Return the sampled parameters for input to the model
@@ -235,9 +237,12 @@ def forward(self, exp_data, exp_cycles, exp_types, exp_control, exp_results=None
             predictions[:, :, 0], exp_cycles, exp_types
         )

+        if self.nan_num:
+            results = torch.nan_to_num(results)
+
         # Setup the full noise, which can be type specific
         if self.type_noise:
-            full_noise = torch.empty(exp_data.shape[-1])
+            full_noise = torch.empty(exp_data.shape[-1], device = self.eps.device)
             for i in experiments.exp_map.values():
                 full_noise[exp_types == i] = self.eps[i]
         else:
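As a rough illustration of what the two changes do (a sketch, not part of the commit; the tensors and values below are made up), torch.nan_to_num replaces any NaN entries in the predicted results with a finite value (0.0 by default) before the likelihood is evaluated, and the per-type noise buffer is now allocated on the same device as self.eps so the masked assignment also works when the model runs on a GPU:

import torch

# Hypothetical predictions containing a NaN
results = torch.tensor([[1.0, float("nan")], [2.0, 3.0]])

# What the nan_num=True branch does: NaNs become 0.0 (torch's default),
# so the downstream observation statement no longer sees NaNs
results = torch.nan_to_num(results)
print(results)  # tensor([[1., 0.], [2., 3.]])

# The second change: allocate the per-type noise on eps's device so the
# boolean-mask assignment never mixes CPU and GPU tensors
eps = torch.tensor([0.1, 0.2])           # per test-type noise (illustrative)
exp_types = torch.tensor([0, 1, 0])      # test type of each experiment (illustrative)
full_noise = torch.empty(exp_types.shape[-1], device=eps.device)
for i in range(len(eps)):
    full_noise[exp_types == i] = eps[i]
print(full_noise)  # tensor([0.1000, 0.2000, 0.1000])

A caller opting into the NaN handling would pass the new flag at construction, e.g. StatisticalModel(maker, names, locs, scales, eps, nan_num=True); the default of False keeps the previous behaviour.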
