Skip to content

Commit

Permalink
debugging sync
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 23, 2023
1 parent 716ee7a commit dfb929d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
15 changes: 10 additions & 5 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def get_pfs_and_pbs(
# 2) Also we should be able to store the outputs of the NN forward pass
# to calculate both the on and off policy distributions at once.
# if not isinstance(estimator_outputs, type(None)):
# import IPython; IPython.embed()
try:
idx = torch.ones(trajectories.actions.batch_shape).bool()
estimator_outputs = estimator_outputs[idx]
estimator_outputs = estimator_outputs[~trajectories.actions.is_dummy]
except:
raise Exception(
"GFlowNet is off policy but no estimator_outputs found."
Expand All @@ -169,14 +169,19 @@ def get_pfs_and_pbs(
# match... to be removed asap.

# import IPython; IPython.embed()
# new_estimator_outputs = self.pf(valid_states)
new_estimator_outputs = self.pf(valid_states)
# print("recomputed-original matches / total:\n{}/{}".format(
# (new_estimator_outputs == estimator_outputs).sum(),
# new_estimator_outputs.nelement(),
# )
# )
# idx = ~(new_estimator_outputs == estimator_outputs)
# # print("Mismatches Elements={}".format(valid_states.tensor[idx]))
print("Mismatches Indices={}".format(
(new_estimator_outputs != estimator_outputs).nonzero(as_tuple=True)[0]
)
)
idx = ~(new_estimator_outputs == estimator_outputs)
print("Mismatches Elements={}".format(valid_states.tensor[idx]))

# print("Mismatches Diffs ={}".format(
# torch.abs(new_estimator_outputs[idx] - estimator_outputs[idx]).detach().numpy()
# ))
Expand Down
3 changes: 0 additions & 3 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def sample_trajectories(
trajectories_actions += [actions]
trajectories_logprobs += [log_probs]

import IPython

IPython.embed()
if self.estimator.is_backward:
new_states = env.backward_step(states, actions)
else:
Expand Down

0 comments on commit dfb929d

Please sign in to comment.