to be deleted #144

Closed · wants to merge 19 commits into from
Changes from all commits
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
@@ -20,15 +20,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
pip install .[all]
- name: Black Formatting
run: |
black .

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401 --show-source --statistics
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401,F821 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=18 --max-line-length=89 --statistics

6 changes: 3 additions & 3 deletions .github/workflows/python-package-conda.yml
@@ -13,14 +13,15 @@ jobs:
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: 3.10
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
python -m pip install --upgrade pip
pip install .[all]
- name: Lint with flake8
run: |
conda install flake8
@@ -30,5 +31,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
conda install pytest
pytest
3 changes: 1 addition & 2 deletions .gitignore
@@ -182,6 +182,5 @@ wandb/

scripts.py

*.ipynb
models/
*.DS_Store
*.DS_Store
4 changes: 1 addition & 3 deletions README.md
@@ -21,9 +21,7 @@

## Installing the package

The codebase requires python >= 3.10

To install the latest stable version:
The codebase requires python >= 3.10. To install the latest stable version:

```bash
pip install torchgfn
23 changes: 13 additions & 10 deletions pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
version = "1.1"
version = "1.1.1"
description = "A torch implementation of GFlowNets"
authors = ["Salem Lahou <[email protected]>", "Joseph Viviano <[email protected]>", "Victor Schmidt <[email protected]>"]
license = "MIT"
@@ -26,16 +26,17 @@ torch = ">=1.9.0"
torchtyping = ">=0.1.4"

# dev dependencies.
black = { version = "22.3.0", optional = true }
black = { version = "*", optional = true }
flake8 = { version = "*", optional = true }
gitmopy = { version = "*", optional = true }
myst-parser = { version = "*", optional = true }
pre-commit = { version = "*", optional = true }
pytest = { version = "*", optional = true }
renku-sphinx-theme = { version = "*", optional = true }
sphinx = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
sphinx-autoapi = { version = "*", optional = true }
sphinx-math-dollar = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
tox = { version = "*", optional = true }

# scripts dependencies.
@@ -52,30 +53,32 @@ dev = [
"pre-commit",
"pytest",
"renku-sphinx-theme",
"sphinx",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx_rtd_theme",
"tox"
"sphinx",
"tox",
"flake8",
]

scripts = ["tqdm", "wandb", "scikit-learn", "scipy"]

all = [
"black",
"flake8",
"myst-parser",
"pre-commit",
"pytest",
"renku-sphinx-theme",
"scikit-learn",
"scipy",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx",
"tox",
"black",
"myst-parser",
"tqdm",
"wandb",
"scikit-learn",
"scipy"
]

[project.urls]
41 changes: 27 additions & 14 deletions src/gfn/env.py
@@ -23,16 +23,21 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes an environment.

Args:
s0: Representation of the initial state. All individual states would be of the same shape.
sf (optional): Representation of the final state. Only used for a human readable representation of
the states or trajectories.
device_str (Optional[str], optional): 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0.
preprocessor (Optional[Preprocessor], optional): a Preprocessor object that converts raw states to a tensor that can be fed
into a neural network. Defaults to None, in which case the IdentityPreprocessor is used.
s0: Representation of the initial state. All individual states will be of
the same shape.
sf: Representation of the final state. Only used for a human
readable representation of the states or trajectories.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
preprocessor: a Preprocessor object that converts raw states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.device = torch.device(device_str) if device_str is not None else s0.device

@@ -53,6 +58,7 @@ def __init__(

self.preprocessor = preprocessor
self.is_discrete = False
self.log_reward_clip = log_reward_clip

@abstractmethod
def make_States_class(self) -> type[States]:
@@ -184,12 +190,12 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or log_reward needs to be implemented."""
return torch.exp(self.log_reward(final_states))
"""This (and potentially log_reward) needs to be implemented."""
raise NotImplementedError("reward function not implemented")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or reward needs to be implemented."""
raise NotImplementedError("log_reward function not implemented")
"""Calculates the log reward (clipping small rewards)."""
return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

@property
def log_partition(self) -> float:
@@ -203,8 +209,9 @@ class DiscreteEnv(Env, ABC):
"""
Base class for discrete environments, where actions are represented by a number in
{0, ..., n_actions - 1}, the last one being the exit action.
`DiscreteEnv` allow specifying the validity of actions (forward and backward), via mask tensors, that
are directly attached to `States` objects.

`DiscreteEnv` allows specifying the validity of actions (forward and backward)
via mask tensors that are directly attached to `States` objects.
"""

def __init__(
@@ -214,16 +221,22 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes a discrete environment.

Args:
n_actions: The number of actions in the environment.

s0: The initial state tensor (shared among all trajectories).
sf: The final state tensor (shared among all trajectories).
device_str: String representation of a torch.device.
preprocessor: An optional preprocessor for intermediate states.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.n_actions = n_actions
super().__init__(s0, sf, device_str, preprocessor)
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
self.is_discrete = True
self.log_reward_clip = log_reward_clip

def make_Actions_class(self) -> type[Actions]:
env = self
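To make the new contract concrete, here is a minimal sketch (not part of the diff) of what the base `Env.log_reward` now does: it takes the log of `reward()` and clips very small values at `log_reward_clip`, so a zero reward maps to -100 rather than -inf.

```python
import torch

# Minimal sketch of the new default, assuming the -100.0 default for log_reward_clip.
reward = torch.tensor([0.0, 1.0, 2.5])   # stand-in for self.reward(final_states)
log_reward_clip = -100.0
log_reward = torch.log(reward).clip(log_reward_clip)
print(log_reward)  # tensor([-100.0000, 0.0000, 0.9163])
```

Subclasses now provide `reward` (or override `log_reward` directly); the `Box` environment below takes the first path and the discrete EBM environment takes the second.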
8 changes: 5 additions & 3 deletions src/gfn/gym/box.py
@@ -20,6 +20,7 @@ def __init__(
R2: float = 2.0,
epsilon: float = 1e-4,
device_str: Literal["cpu", "cuda"] = "cpu",
log_reward_clip: float = -100.0,
):
assert 0 < delta <= 1, "delta must be in (0, 1]"
self.delta = delta
@@ -30,7 +31,7 @@
self.R1 = R1
self.R2 = R2

super().__init__(s0=s0)
super().__init__(s0=s0, log_reward_clip=log_reward_clip)

def make_States_class(self) -> type[States]:
env = self
@@ -116,14 +117,15 @@ def is_action_valid(

return True

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Reward is distance from the goal point."""
R0, R1, R2 = (self.R0, self.R1, self.R2)
ax = abs(final_states.tensor - 0.5)
reward = (
R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2
)

return reward.log()
return reward

@property
def log_partition(self) -> float:
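For reference, a short sketch of the `reward` above evaluated by hand. The defaults `R0 = 0.1` and `R1 = 0.5` are assumptions (only `R2 = 2.0` appears in this diff), and the two states are made up for illustration.

```python
import torch

R0, R1, R2 = 0.1, 0.5, 2.0                                  # R0 and R1 are assumed defaults
final_states = torch.tensor([[0.50, 0.50], [0.85, 0.15]])   # two 2-D states in [0, 1]^2

ax = (final_states - 0.5).abs()
reward = R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2
print(reward)  # tensor([0.1000, 2.6000]): base reward everywhere, extra mass near the corners
```

With this change, `Box` no longer defines `log_reward`; the base `Env` derives it as `reward.log()` clipped at `log_reward_clip`.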
45 changes: 24 additions & 21 deletions src/gfn/gym/discrete_ebm.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, Tuple, cast
from typing import ClassVar, Literal, Tuple

import torch
import torch.nn as nn
@@ -48,14 +48,19 @@ def __init__(
alpha: float = 1.0,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["Identity", "Enum"] = "Identity",
log_reward_clip: float = -100.0,
):
"""Discrete EBM environment.

Args:
ndim (int, optional): dimension D of the sampling space {0, 1}^D.
energy (EnergyFunction): energy function of the EBM. Defaults to None. If None, the Ising model with Identity matrix is used.
alpha (float, optional): interaction strength the EBM. Defaults to 1.0.
device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
ndim: dimension D of the sampling space {0, 1}^D.
energy: energy function of the EBM. Defaults to None. If
None, the Ising model with Identity matrix is used.
alpha: interaction strength of the EBM. Defaults to 1.0.
device_str: "cpu" or "cuda". Defaults to "cpu".
preprocessor_name: "Identity" or "Enum". Defaults to "Identity".
log_reward_clip: Minimum log reward allowable (namely, for log(0)).
"""
self.ndim = ndim

@@ -89,6 +94,7 @@ def __init__(
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
log_reward_clip=log_reward_clip,
)

def make_States_class(self) -> type[DiscreteStates]:
@@ -133,16 +139,7 @@ def make_masks(
return forward_masks, backward_masks

def update_masks(self) -> None:
# The following two lines are for typing only.
self.forward_masks = cast(
TT["batch_shape", "n_actions", torch.bool],
self.forward_masks,
)
self.backward_masks = cast(
TT["batch_shape", "n_actions - 1", torch.bool],
self.backward_masks,
)

self.set_default_typing()
self.forward_masks[..., : env.ndim] = self.tensor == -1
self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1
self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1)
@@ -183,16 +180,22 @@ def maskless_backward_step(
# action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
# A backward action asks "what index should be set back to -1", hence the fmod
# to enable wrapping of indices.
return states.tensor.scatter(
-1,
actions.tensor.fmod(self.ndim),
-1,
)
return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1)

def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""Not used during training but provided for completeness.

Note that the effect of log-reward clipping is visible in these values.
"""
return torch.exp(self.log_reward(final_states))

def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""The energy weighted by alpha is our log reward."""
raw_states = final_states.tensor
canonical = 2 * raw_states - 1
return -self.alpha * self.energy(canonical)
log_reward = -self.alpha * self.energy(canonical)

return log_reward.clip(self.log_reward_clip)

def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]:
"""The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3"""
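A short sketch of the clipping added to `log_reward` above; the energy values are hypothetical stand-ins for `self.energy(canonical)`, while `alpha` and the clip value match the defaults in the diff.

```python
import torch

alpha, log_reward_clip = 1.0, -100.0
energy = torch.tensor([3.0, 250.0])   # second state has a very high (made-up) energy

log_reward = (-alpha * energy).clip(log_reward_clip)
print(log_reward)  # tensor([-3., -100.]): the -250 value is clipped up to -100

# reward() is exp(log_reward), so the effect of the clip carries over to reward values.
print(torch.exp(log_reward))
```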