From 43e60954f9deb9dceb91ffae1fa70be6971caca3 Mon Sep 17 00:00:00 2001 From: inel Date: Fri, 11 Oct 2024 10:34:54 +0200 Subject: [PATCH] small fixes --- src/gfn/containers/replay_buffer.py | 6 +++--- src/gfn/env.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index a5205249..47352b82 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -155,8 +155,8 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States] # Add the terminating states to the buffer. if self.terminating_states is not None: - assert terminating_states is not None - self.terminating_states.extend(terminating_states) + assert terminating_states is not None + self.terminating_states.extend(terminating_states) # Sort terminating states by logreward as well. self.terminating_states = self.terminating_states[ix] @@ -164,7 +164,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): """Adds a training object to the buffer.""" - terminating_states = None + terminating_states = None if isinstance(training_objects, tuple): assert self.objects_type == "states" and self.terminating_states is not None training_objects, terminating_states = training_objects diff --git a/src/gfn/env.py b/src/gfn/env.py index c1569235..c6097243 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -271,7 +271,7 @@ def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: @property def log_partition(self) -> float: "Returns the logarithm of the partition function." 
- return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -432,20 +432,20 @@ def get_terminating_states_indices( @property def n_states(self) -> int: - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @property def n_terminating_states(self) -> int: - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @property def true_dist_pmf(self) -> TT["n_states", torch.float]: "Returns a one-dimensional tensor representing the true distribution." - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -456,7 +456,7 @@ def all_states(self) -> DiscreteStates: This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states) """ - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -467,6 +467,6 @@ def terminating_states(self) -> DiscreteStates: This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states) """ - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" )