Skip to content

Commit

Permalink
Merge pull request #232 from Quoding/buffer_filter_fix
Browse files Browse the repository at this point in the history
Fix PrioritizedReplayBuffer filtering
  • Loading branch information
josephdviviano authored Jan 24, 2025
2 parents fb17e2b + 1649660 commit 59a1efa
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
# dim=-1,
# )

# If all trajectories were filtered, stop there.
if not len(training_objects):
return

if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
Expand Down

0 comments on commit 59a1efa

Please sign in to comment.