Skip to content

Commit

Permalink
Fix PrioritizedReplayBuffer filtering
Browse files Browse the repository at this point in the history
Fix error where add crashes if no trajectory is better than buffer
trajectories.
  • Loading branch information
alexandrelarouche committed Jan 22, 2025
1 parent 6f132a8 commit 1649660
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
# dim=-1,
# )

# If all trajectories were filtered, stop there.
if not len(training_objects):
return

if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
Expand Down

0 comments on commit 1649660

Please sign in to comment.