diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 4cdb136e..60c97635 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -154,9 +154,9 @@ def get_pfs_and_pbs( # 2) Also we should be able to store the outputs of the NN forward pass # to calculate both the on and off policy distributions at once. # if not isinstance(estimator_outputs, type(None)): + # import IPython; IPython.embed() try: - idx = torch.ones(trajectories.actions.batch_shape).bool() - estimator_outputs = estimator_outputs[idx] + estimator_outputs = estimator_outputs[~trajectories.actions.is_dummy] except: raise Exception( "GFlowNet is off policy but no estimator_outputs found." @@ -169,14 +169,19 @@ def get_pfs_and_pbs( # match... to be removed asap. # import IPython; IPython.embed() - # new_estimator_outputs = self.pf(valid_states) + new_estimator_outputs = self.pf(valid_states) # print("recomputed-original matches / total:\n{}/{}".format( # (new_estimator_outputs == estimator_outputs).sum(), # new_estimator_outputs.nelement(), # ) # ) - # idx = ~(new_estimator_outputs == estimator_outputs) - # # print("Mismatches Elements={}".format(valid_states.tensor[idx])) + print("Mismatches Indices={}".format( + (new_estimator_outputs != estimator_outputs).nonzero(as_tuple=True)[0] + ) + ) + idx = ~(new_estimator_outputs == estimator_outputs) + print("Mismatches Elements={}".format(valid_states.tensor[idx])) + # print("Mismatches Diffs ={}".format( # torch.abs(new_estimator_outputs[idx] - estimator_outputs[idx]).detach().numpy() # )) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index fc7665da..bec125f8 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -178,9 +178,6 @@ def sample_trajectories( trajectories_actions += [actions] trajectories_logprobs += [log_probs] - import IPython - - IPython.embed() if self.estimator.is_backward: new_states = env.backward_step(states, actions) else: