Skip to content

Commit

Permalink
Merge pull request #192 from inelgnu/fixes
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
josephdviviano authored Oct 11, 2024
2 parents 13c059a + 6758954 commit c62b5b0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pip install torchgfn[scripts]
To install the cutting edge version (from the `main` branch):

```bash
git clone https://github.com/saleml/torchgfn.git
git clone https://github.com/GFNOrg/torchgfn.git
conda create -n gfn python=3.10
conda activate gfn
cd torchgfn
Expand Down
5 changes: 2 additions & 3 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)
assert self.terminating_states is not None
self.terminating_states.extend(self.terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects
Expand Down
12 changes: 6 additions & 6 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
@property
def log_partition(self) -> float:
"Returns the logarithm of the partition function."
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

Expand Down Expand Up @@ -432,20 +432,20 @@ def get_terminating_states_indices(

@property
def n_states(self) -> int:
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

@property
def n_terminating_states(self) -> int:
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

@property
def true_dist_pmf(self) -> TT["n_states", torch.float]:
"Returns a one-dimensional tensor representing the true distribution."
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

Expand All @@ -456,7 +456,7 @@ def all_states(self) -> DiscreteStates:
This should satisfy:
self.get_states_indices(self.all_states) == torch.arange(self.n_states)
"""
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

Expand All @@ -467,6 +467,6 @@ def terminating_states(self) -> DiscreteStates:
This should satisfy:
self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)
"""
return NotImplementedError(
raise NotImplementedError(
"The environment does not support enumeration of states"
)

0 comments on commit c62b5b0

Please sign in to comment.