diff --git a/README.md b/README.md index cb1b336b..afa2d504 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ pip install torchgfn[scripts] To install the cutting edge version (from the `main` branch): ```bash -git clone https://github.com/saleml/torchgfn.git +git clone https://github.com/GFNOrg/torchgfn.git conda create -n gfn python=3.10 conda activate gfn cd torchgfn diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index a5205249..47352b82 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -155,8 +155,8 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States] # Add the terminating states to the buffer. if self.terminating_states is not None: - assert terminating_states is not None - self.terminating_states.extend(terminating_states) + assert self.terminating_states is not None + self.terminating_states.extend(self.terminating_states) # Sort terminating states by logreward as well. self.terminating_states = self.terminating_states[ix] @@ -164,7 +164,6 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States] def add(self, training_objects: Transitions | Trajectories | tuple[States]): """Adds a training object to the buffer.""" - terminating_states = None if isinstance(training_objects, tuple): assert self.objects_type == "states" and self.terminating_states is not None training_objects, terminating_states = training_objects diff --git a/src/gfn/env.py b/src/gfn/env.py index c1569235..c6097243 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -271,7 +271,7 @@ def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: @property def log_partition(self) -> float: "Returns the logarithm of the partition function." - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -432,20 +432,20 @@ def get_terminating_states_indices( @property def n_states(self) -> int: - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @property def n_terminating_states(self) -> int: - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @property def true_dist_pmf(self) -> TT["n_states", torch.float]: "Returns a one-dimensional tensor representing the true distribution." - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -456,7 +456,7 @@ def all_states(self) -> DiscreteStates: This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states) """ - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -467,6 +467,6 @@ def terminating_states(self) -> DiscreteStates: This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states) """ - return NotImplementedError( + raise NotImplementedError( "The environment does not support enumeration of states" )