Merge pull request #143 from GFNOrg/easier_environment_definition
Easier environment definition
saleml authored Nov 25, 2023
2 parents d9ca558 + 84bb169 commit 1603723
Showing 18 changed files with 3,168 additions and 95 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
@@ -20,15 +20,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
pip install .[all]
- name: Black Formatting
run: |
black .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401 --show-source --statistics
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401,F821 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=18 --max-line-length=89 --statistics
6 changes: 3 additions & 3 deletions .github/workflows/python-package-conda.yml
@@ -13,14 +13,15 @@ jobs:
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: 3.10
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
python -m pip install --upgrade pip
pip install .[all]
- name: Lint with flake8
run: |
conda install flake8
@@ -30,5 +31,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
conda install pytest
pytest
3 changes: 1 addition & 2 deletions .gitignore
@@ -182,6 +182,5 @@ wandb/

scripts.py

*.ipynb
models/
*.DS_Store
*.DS_Store
4 changes: 1 addition & 3 deletions README.md
@@ -21,9 +21,7 @@

## Installing the package

The codebase requires python >= 3.10

To install the latest stable version:
The codebase requires python >= 3.10. To install the latest stable version:

```bash
pip install torchgfn
23 changes: 13 additions & 10 deletions pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
version = "1.1"
version = "1.1.1"
description = "A torch implementation of GFlowNets"
authors = ["Salem Lahou <[email protected]>", "Joseph Viviano <[email protected]>", "Victor Schmidt <[email protected]>"]
license = "MIT"
@@ -26,16 +26,17 @@ torch = ">=1.9.0"
torchtyping = ">=0.1.4"

# dev dependencies.
black = { version = "22.3.0", optional = true }
black = { version = "*", optional = true }
flake8 = { version = "*", optional = true }
gitmopy = { version = "*", optional = true }
myst-parser = { version = "*", optional = true }
pre-commit = { version = "*", optional = true }
pytest = { version = "*", optional = true }
renku-sphinx-theme = { version = "*", optional = true }
sphinx = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
sphinx-autoapi = { version = "*", optional = true }
sphinx-math-dollar = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
tox = { version = "*", optional = true }

# scripts dependencies.
@@ -52,30 +53,32 @@ dev = [
"pre-commit",
"pytest",
"renku-sphinx-theme",
"sphinx",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx_rtd_theme",
"tox"
"sphinx",
"tox",
"flake8",
]

scripts = ["tqdm", "wandb", "scikit-learn", "scipy"]

all = [
"black",
"flake8",
"myst-parser",
"pre-commit",
"pytest",
"renku-sphinx-theme",
"scikit-learn",
"scipy",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx",
"tox",
"black",
"myst-parser",
"tqdm",
"wandb",
"scikit-learn",
"scipy"
]

[project.urls]
44 changes: 30 additions & 14 deletions src/gfn/env.py
@@ -23,16 +23,21 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes an environment.
Args:
s0: Representation of the initial state. All individual states would be of the same shape.
sf (optional): Representation of the final state. Only used for a human readable representation of
the states or trajectories.
device_str (Optional[str], optional): 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0.
preprocessor (Optional[Preprocessor], optional): a Preprocessor object that converts raw states to a tensor that can be fed
into a neural network. Defaults to None, in which case the IdentityPreprocessor is used.
s0: Representation of the initial state. All individual states would be of
the same shape.
sf: Representation of the final state. Only used for a human
readable representation of the states or trajectories.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
preprocessor: a Preprocessor object that converts raw states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.device = torch.device(device_str) if device_str is not None else s0.device

Expand All @@ -53,6 +58,7 @@ def __init__(

self.preprocessor = preprocessor
self.is_discrete = False
self.log_reward_clip = log_reward_clip

@abstractmethod
def make_States_class(self) -> type[States]:
@@ -184,12 +190,15 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or log_reward needs to be implemented."""
return torch.exp(self.log_reward(final_states))
"""The environment's reward given a state.
This or log_reward must be implemented.
"""
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or reward needs to be implemented."""
raise NotImplementedError("log_reward function not implemented")
"""Calculates the log reward (clipping small rewards)."""
return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

@property
def log_partition(self) -> float:
@@ -203,8 +212,9 @@ class DiscreteEnv(Env, ABC):
"""
Base class for discrete environments, where actions are represented by a number in
{0, ..., n_actions - 1}, the last one being the exit action.
`DiscreteEnv` allow specifying the validity of actions (forward and backward), via mask tensors, that
are directly attached to `States` objects.
`DiscreteEnv` allows for specifying the validity of actions (forward and backward),
via mask tensors, that are directly attached to `States` objects.
"""

def __init__(
@@ -214,16 +224,22 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes a discrete environment.
Args:
n_actions: The number of actions in the environment.
s0: The initial state tensor (shared among all trajectories).
sf: The final state tensor (shared among all trajectories).
device_str: String representation of a torch.device.
preprocessor: An optional preprocessor for intermediate states.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.n_actions = n_actions
super().__init__(s0, sf, device_str, preprocessor)
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
self.is_discrete = True
self.log_reward_clip = log_reward_clip

def make_Actions_class(self) -> type[Actions]:
env = self
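
The hunks above invert the default relationship between `reward` and `log_reward`: the base `Env` now derives `log_reward` from `reward` and clips it at `log_reward_clip` (default `-100.0`), instead of deriving `reward` from a mandatory `log_reward`. A minimal sketch of that derivation, on plain tensors rather than a real `Env` subclass (the reward values are hypothetical):

```python
import torch

# Assumed from the diff: log_reward(final_states) is now
# torch.log(reward(final_states)).clip(log_reward_clip).
log_reward_clip = -100.0
rewards = torch.tensor([0.0, 2.0])          # hypothetical per-state rewards; zero is allowed
log_rewards = torch.log(rewards).clip(log_reward_clip)
print(log_rewards)  # tensor([-100.0000, 0.6931]) -- log(0) = -inf is clipped to -100.0
```

Subclasses therefore only need to implement one of the two methods, and the clipping keeps zero-reward terminal states from producing `-inf` log rewards.
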
8 changes: 5 additions & 3 deletions src/gfn/gym/box.py
@@ -20,6 +20,7 @@ def __init__(
R2: float = 2.0,
epsilon: float = 1e-4,
device_str: Literal["cpu", "cuda"] = "cpu",
log_reward_clip: float = -100.0,
):
assert 0 < delta <= 1, "delta must be in (0, 1]"
self.delta = delta
@@ -30,7 +31,7 @@
self.R1 = R1
self.R2 = R2

super().__init__(s0=s0)
super().__init__(s0=s0, log_reward_clip=log_reward_clip)

def make_States_class(self) -> type[States]:
env = self
@@ -116,14 +117,15 @@ def is_action_valid(

return True

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Reward is distance from the goal point."""
R0, R1, R2 = (self.R0, self.R1, self.R2)
ax = abs(final_states.tensor - 0.5)
reward = (
R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2
)

return reward.log()
return reward

@property
def log_partition(self) -> float:
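
With this change, `Box` implements `reward` directly and inherits the clipped `log_reward` from the base `Env`. Below is a small sketch of the piecewise reward surface from the hunk, with hypothetical values for `R0` and `R1` (only `R2`'s default of `2.0` is visible above):

```python
import torch

R0, R1, R2 = 0.1, 0.5, 2.0              # R0 and R1 are assumptions made for this sketch
x = torch.tensor([[0.95, 0.05],          # far from the centre in both dimensions
                  [0.50, 0.50]])         # exactly at the centre
ax = (x - 0.5).abs()
# Same piecewise expression as the diff (float casts added here for clarity).
reward = (
    R0
    + (0.25 < ax).float().prod(-1) * R1
    + ((0.3 < ax) & (ax < 0.4)).float().prod(-1) * R2
)
print(reward)  # tensor([0.6000, 0.1000]) with the assumed R0/R1 values
# The base Env then derives log_reward as reward.log().clip(log_reward_clip).
```
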
45 changes: 24 additions & 21 deletions src/gfn/gym/discrete_ebm.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, Tuple, cast
from typing import ClassVar, Literal, Tuple

import torch
import torch.nn as nn
@@ -48,14 +48,19 @@ def __init__(
alpha: float = 1.0,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["Identity", "Enum"] = "Identity",
log_reward_clip: float = -100.0,
):
"""Discrete EBM environment.
Args:
ndim (int, optional): dimension D of the sampling space {0, 1}^D.
energy (EnergyFunction): energy function of the EBM. Defaults to None. If None, the Ising model with Identity matrix is used.
alpha (float, optional): interaction strength the EBM. Defaults to 1.0.
device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
ndim: dimension D of the sampling space {0, 1}^D.
energy: energy function of the EBM. Defaults to None. If
None, the Ising model with Identity matrix is used.
alpha: interaction strength of the EBM. Defaults to 1.0.
device_str: "cpu" or "cuda". Defaults to "cpu".
preprocessor_name: "Identity" or "Enum". Defaults to "Identity".
log_reward_clip: Minimum log reward allowable (namely, for log(0)).
"""
self.ndim = ndim

@@ -89,6 +94,7 @@ def __init__(
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
log_reward_clip=log_reward_clip,
)

def make_States_class(self) -> type[DiscreteStates]:
@@ -133,16 +139,7 @@ def make_masks(
return forward_masks, backward_masks

def update_masks(self) -> None:
# The following two lines are for typing only.
self.forward_masks = cast(
TT["batch_shape", "n_actions", torch.bool],
self.forward_masks,
)
self.backward_masks = cast(
TT["batch_shape", "n_actions - 1", torch.bool],
self.backward_masks,
)

self.set_default_typing()
self.forward_masks[..., : env.ndim] = self.tensor == -1
self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1
self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1)
@@ -183,16 +180,22 @@ def maskless_backward_step(
# action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
# A backward action asks "what index should be set back to -1", hence the fmod
# to enable wrapping of indices.
return states.tensor.scatter(
-1,
actions.tensor.fmod(self.ndim),
-1,
)
return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1)

def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""Not used during training but provided for completeness.
Note the effect of clipping will be seen in these values.
"""
return torch.exp(self.log_reward(final_states))

def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""The energy weighted by alpha is our log reward."""
raw_states = final_states.tensor
canonical = 2 * raw_states - 1
return -self.alpha * self.energy(canonical)
log_reward = -self.alpha * self.energy(canonical)

return log_reward.clip(self.log_reward_clip)

def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]:
"""The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3"""
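
The `DiscreteEBM` change mirrors the base class: `log_reward` is the `alpha`-weighted negative energy, now clipped at `log_reward_clip`, and `reward` is its exponential. A sketch of that computation with a stand-in nearest-neighbour Ising energy (the environment's actual default energy function is not shown in this hunk):

```python
import torch

alpha, log_reward_clip = 1.0, -100.0
raw_states = torch.tensor([[1, 0, 1], [0, 0, 0]])    # hypothetical final states in {0, 1}^D
canonical = 2 * raw_states - 1                        # map to {-1, 1}^D, as in the diff

# Stand-in energy for this sketch: negative sum of nearest-neighbour products.
energy = -(canonical[:, :-1] * canonical[:, 1:]).sum(-1).float()

log_reward = (-alpha * energy).clip(log_reward_clip)  # the new log_reward(): clipped
reward = torch.exp(log_reward)                        # the new reward(): exp of the clipped value
print(log_reward, reward)
```
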
3 changes: 2 additions & 1 deletion src/gfn/gym/helpers/box_utils.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily
from torchtyping import TensorType as TT

@@ -600,8 +601,8 @@ class BoxStateFlowModule(NeuralNet):
"""A deep neural network for the state flow function."""

def __init__(self, logZ_value: torch.Tensor, **kwargs):
self.logZ_value = logZ_value
super().__init__(**kwargs)
self.logZ_value = nn.Parameter(logZ_value)

def forward(
self, preprocessed_states: TT["batch_shape", "input_dim", float]
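
The `BoxStateFlowModule` fix moves the `logZ_value` assignment after `super().__init__()` and wraps it in `nn.Parameter`, so the value is registered on the module and trained along with the rest of the network. A standalone sketch of that registration behaviour (the class name is a hypothetical stand-in, not the torchgfn class):

```python
import torch
import torch.nn as nn

class FlowHead(nn.Module):                  # hypothetical stand-in for BoxStateFlowModule
    def __init__(self, logZ_value: torch.Tensor):
        super().__init__()                  # must run before nn.Parameter assignment
        self.logZ_value = nn.Parameter(logZ_value)

module = FlowHead(torch.zeros(1))
print([name for name, _ in module.named_parameters()])  # ['logZ_value'] -- it now receives gradients
```
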