Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: upgrade python #254

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions .github/workflows/tests_linters.yml
Original file line number Diff line number Diff line change
@@ -1,34 +1,51 @@
name: Tests and Linters 🧪

on: [ push, pull_request ]
on: [ pull_request ]

jobs:
tests-and-linters:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
runs-on: "${{ matrix.os }}"
timeout-minutes: 20

strategy:
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest]

steps:
- name: Install dependencies for viewer test
run: sudo apt-get update && sudo apt-get install -y xvfb

- name: Checkout jumanji 🐍
uses: actions/checkout@v3
- uses: actions/setup-python@v4
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.4.26"
enable-cache: true
cache-dependency-glob: "requirements/requirements**.txt" # invalidate cache when requirements file changes

- uses: actions/setup-python@v5
with:
python-version: "${{ matrix.python-version }}"

- name: Install python dependencies 🔧
run: pip install .[dev,train]
run: uv pip install .[dev,train]
env:
UV_SYSTEM_PYTHON: 1

- name: Run linters 🖌️
run: pre-commit run --all-files --verbose

- name: Run tests 🧪
run: pytest -n 2 --cov=jumanji --cov-report=term-missing --junit-xml=test-results.xml -vv jumanji

- name: Run coverage
run: |
coverage html --directory=coverage_html_report
coverage report --fail-under=0.97

- name: Test build docs 📖
run: mkdocs build --verbose --site-dir docs_public
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cython_debug/
# MacBook Finder
.DS_Store

3.8/
3.10/
jumanji_env/
**/outputs/
*.xml
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repos:
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
rev: 7.1.1
hooks:
- id: flake8
name: "Linter"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Alternatively, you can install the latest development version directly from GitH
pip install git+https://github.com/instadeepai/jumanji.git
```

Jumanji has been tested on Python 3.8 and 3.9.
Jumanji has been tested on Python 3.10, 3.11 and 3.12.
Note that because the installation of JAX differs depending on your hardware accelerator,
we advise users to explicitly install the correct JAX version (see the
[official installation guide](https://github.com/google/jax#installation)).
Expand Down
2 changes: 1 addition & 1 deletion jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array:

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns the discount spec. By default, this is assumed to be a single float between 0 and 1.
"""Returns the discount spec. By default, this is assumed to be a float between 0 and 1.

Returns:
discount_spec: a `specs.BoundedArray` spec.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board.
"""This function returns a `Matplotlib` figure and axes for displaying the 2048 game board.

Returns:
A tuple containing the figure and axes objects.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/graph_coloring/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> animation.FuncAnimation:
"""Creates an animated gif of the `GraphColoring` environment based on the sequence of game states.
"""Creates an animated gif of the `GraphColoring` environment based on a sequence of states.

Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/logic/sliding_tile_puzzle/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
"""Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states.
"""Creates an animated gif of the sliding tiles puzzle game based on a sequence of states.

Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down Expand Up @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the game puzzle.
"""This function returns a `Matplotlib` figure and axes for displaying the puzzle.

Returns:
A tuple containing the figure and axes objects.
Expand Down
6 changes: 3 additions & 3 deletions jumanji/environments/packing/bin_pack/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def close(self) -> None:
def _make_observation_and_extras(
self, state: State
) -> Tuple[State, Observation, Dict]:
"""Computes the observation and the environment metrics to include in `timestep.extras`. Also
updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained
by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""Computes the observation and the environment metrics to include in `timestep.extras`.
Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is
obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""
obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems(
state.ems, state.ems_mask
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/packing/flat_pack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


class InstanceGenerator(abc.ABC):
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible
for generating a problem instance when the environment is reset.
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is
responsible for generating a problem instance when the environment is reset.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/mmst/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State:


class SplitRandomGenerator(Generator):
"""Generates a random environments that is solvable by spliting the graph into multiple sub graphs.
"""Generates a random environments that is solvable by spliting the graph into sub graphs.

Returns a graph and with a desired number of edges and nodes to connect per agent.
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/environments/routing/multi_cvrp/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None:
ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id))

def _add_tour(self, ax: plt.Axes, state: State) -> None:
"""Add the customers and the depot to the plot, and draw each route in the tour in a different
colour. The tour is the entire trajectory between the visited customers and a route is a
trajectory either starting and ending at the depot or starting at the depot and ending at
the current city."""
"""Add the customers and the depot to the plot, and draw each route in the tour in a
different colour. The tour is the entire trajectory between the visited customers and a
route is a trajectory either starting and ending at the depot or starting at the depot
and ending at the current city."""
x_coords, y_coords = (
state.nodes.coordinates[:, 0] / self._map_max,
state.nodes.coordinates[:, 1] / self._map_max,
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/routing/robot_warehouse/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

@pytest.fixture(scope="module")
def robot_warehouse_env() -> RobotWarehouse:
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns,
a column height of 2, sensor range of 1 and a request queue size of 4."""
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf
columns, a column height of 2, sensor range of 1 and a request queue size of 4."""
generator = RandomGenerator(
shelf_rows=1,
shelf_columns=3,
Expand Down
4 changes: 3 additions & 1 deletion jumanji/environments/routing/robot_warehouse/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]:

@cached_property
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].
"""Returns the action spec. 5 actions:
[0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].

Since this is a multi-agent environment, the environment expects an array of actions.
This array is of shape (num_agents,).
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@


class Spec(abc.ABC, Generic[T]):
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested
specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the
`dm_env` object."""
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for
nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from
the `dm_env` object."""

def __init__(
self,
Expand Down Expand Up @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec":


class Array(Spec[chex.Array]):
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments.
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments.

An `Array` spec allows an API to describe the arrays that it accepts or returns, before that
array exists.
Expand Down
14 changes: 7 additions & 7 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def test_array(self) -> None:
converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None:
converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None:
converted_spec: dm_env.specs.DiscreteArray = (
specs.jumanji_specs_to_dm_env_specs(jumanji_spec)
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand Down Expand Up @@ -675,7 +675,7 @@ def test_array(self) -> None:
jumanji_spec = specs.Array((1, 2), jnp.int32)
gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert_trees_all_equal(converted_spec.low, gym_space.low)
assert_trees_all_equal(converted_spec.high, gym_space.high)
assert converted_spec.shape == gym_space.shape
Expand All @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None:
)
gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert_trees_all_equal(converted_spec.low, gym_space.low)
Expand All @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None:
jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32)
gym_space = gym.spaces.Discrete(n=5)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert converted_spec.n == gym_space.n
Expand All @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None:
)
gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6])
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert jnp.array_equal(converted_spec.nvec, gym_space.nvec)
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/graph_coloring/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_color_node_block_{block_id+1}",
name=f"cross_attention_color_node_block_{block_id + 1}",
)(color_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/mmst/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_agent_node_block_{block_id+1}",
name=f"cross_attention_agent_node_block_{block_id + 1}",
)(agents_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
4 changes: 2 additions & 2 deletions jumanji/training/networks/tsp/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
transformer_mlp_units: Sequence[int],
name: Optional[str] = None,
):
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self
attention.
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks
of self attention.
"""
super().__init__(name=name)
self.transformer_num_blocks = transformer_num_blocks
Expand Down
52 changes: 49 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,51 @@
[tool.isort]
profile = "black"
[build-system]
requires=["setuptools>=62.6"]
build-backend="setuptools.build_meta"

[project]
name="jumanji"
authors=[{name="InstaDeep Ltd", email="[email protected]"}]
dynamic=["version", "dependencies", "optional-dependencies"]
license={file="LICENSE"}
description="A diverse suite of scalable reinforcement learning environments in JAX"
readme ="README.md"
requires-python=">=3.10"
keywords=["reinforcement-learning", "python", "jax"]
classifiers=[
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
]

[tool.setuptools.packages.find]
include=["jumanji*"]

[tool.setuptools.package-data]
"jumanji" = ["py.typed"]

[tool.setuptools.dynamic]
version={attr="jumanji.version.__version__"}
dependencies={file="requirements/requirements.txt"}
optional-dependencies.dev={file=["requirements/requirements-dev.txt"]}
optional-dependencies.train={file=["requirements/requirements-train.txt"]}


[project.urls]
"Homepage"="https://github.com/instadeep/jumanji"
"Bug Tracker"="https://github.com/instadeep/jumanji/issues"
"Documentation"="https://instadeepai.github.io/jumanji"

[tool.mypy]
python_version = 3.8
python_version = "3.10"
namespace_packages = true
incremental = false
cache_dir = ""
Expand Down Expand Up @@ -47,3 +90,6 @@ module = [
"PIL.*",
]
ignore_missing_imports = true

[tool.isort]
profile = "black"
7 changes: 1 addition & 6 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
black==22.3.0
coverage
flake8==3.9.2
importlib-metadata<5.0
flake8
isort==5.11.5
livereload
mkdocs==1.2.3
Expand All @@ -22,9 +21,5 @@ pytest-cov
pytest-mock
pytest-parallel
pytest-xdist
pytype
scipy>=1.7.3
testfixtures
types-Pillow
types-requests<1.27
types-setuptools
Loading