From f431d80dc03f6ba6e6ed73a7882baa04367cc51e Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 29 Oct 2024 14:06:19 +0200 Subject: [PATCH 1/8] chore: upgrade python --- .github/workflows/tests_linters.yml | 27 ++++++++++++++--- .gitignore | 2 +- README.md | 2 +- pyproject.toml | 47 ++++++++++++++++++++++++++++- 4 files changed, 70 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tests_linters.yml b/.github/workflows/tests_linters.yml index 258378b7a..3af9f0251 100644 --- a/.github/workflows/tests_linters.yml +++ b/.github/workflows/tests_linters.yml @@ -1,6 +1,6 @@ name: Tests and Linters ๐Ÿงช -on: [ push, pull_request ] +on: [ pull_request ] jobs: tests-and-linters: @@ -9,26 +9,43 @@ jobs: strategy: matrix: - python-version: ["3.8", "3.9"] + python-version: ["3.10", "3.11", "3.12"] os: [ubuntu-latest] steps: - name: Install dependencies for viewer test run: sudo apt-get update && sudo apt-get install -y xvfb + - name: Checkout jumanji ๐Ÿ - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + uses: actions/checkout@v4 + + + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "0.4.26" + enable-cache: true + cache-dependency-glob: "requirements/requirements**.txt" # invalidate cache when requirements file changes + + - uses: actions/setup-python@v5 with: python-version: "${{ matrix.python-version }}" + - name: Install python dependencies ๐Ÿ”ง - run: pip install .[dev,train] + run: uv pip install .[dev,train] + env: + UV_SYSTEM_PYTHON: 1 + - name: Run linters ๐Ÿ–Œ๏ธ run: pre-commit run --all-files --verbose + - name: Run tests ๐Ÿงช run: pytest -n 2 --cov=jumanji --cov-report=term-missing --junit-xml=test-results.xml -vv jumanji + - name: Run coverage run: | coverage html --directory=coverage_html_report coverage report --fail-under=0.97 + - name: Test build docs ๐Ÿ“– run: mkdocs build --verbose --site-dir docs_public diff --git a/.gitignore b/.gitignore index 7a4e033c2..ec09148f1 100644 --- a/.gitignore +++ b/.gitignore @@ -150,7 +150,7 @@ cython_debug/ # MacBook Finder .DS_Store -3.8/ +3.10/ jumanji_env/ **/outputs/ *.xml diff --git a/README.md b/README.md index 4866657e3..04b3a437f 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ Alternatively, you can install the latest development version directly from GitH pip install git+https://github.com/instadeepai/jumanji.git ``` -Jumanji has been tested on Python 3.8 and 3.9. +Jumanji has been tested on Python 3.10, 3.11 and 3.12. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the [official installation guide](https://github.com/google/jax#installation)). diff --git a/pyproject.toml b/pyproject.toml index 79ed22fe2..d335fb5ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,53 @@ [tool.isort] profile = "black" +[build-system] +requires=["setuptools>=62.6"] +build-backend="setuptools.build_meta" + +[project] +name="jumanji" +authors=[{name="InstaDeep Ltd", email="clement.bonnet16@gmail.com"}] +dynamic=["version", "dependencies", "optional-dependencies"] +license={file="LICENSE"} +description="A diverse suite of scalable reinforcement learning environments in JAX" +readme ="README.md" +requires-python=">=3.10" +keywords=["reinforcement-learning", "python", "jax"] +classifiers=[ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: Apache Software License", +] + +[tool.setuptools.packages.find] +include=["jumanji*"] + +[tool.setuptools.package-data] +"jumanji" = ["py.typed"] + +[tool.setuptools.dynamic] +version={attr="jumanji.version.__version__"} +dependencies={file="requirements/requirements.txt"} +optional-dependencies.dev={file=["requirements/requirements-dev.txt"]} +optional-dependencies.train={file=["requirements/requirements-train.txt"]} + + +[project.urls] +"Homepage"="https://github.com/instadeep/jumanji" +"Bug Tracker"="https://github.com/instadeep/jumanji/issues" +"Documentation"="https://instadeepai.github.io/jumanji" [tool.mypy] -python_version = 3.8 +python_version = 3.10 namespace_packages = true incremental = false cache_dir = "" From c6810ed7440b92b87c98b2d0c4e3ab19ead59609 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 29 Oct 2024 16:03:16 +0200 Subject: [PATCH 2/8] chore: remove setup.py --- .github/workflows/tests_linters.yml | 2 +- setup.py | 71 ----------------------------- 2 files changed, 1 insertion(+), 72 deletions(-) delete mode 100644 setup.py diff --git a/.github/workflows/tests_linters.yml b/.github/workflows/tests_linters.yml index 3af9f0251..5110c93f1 100644 --- a/.github/workflows/tests_linters.yml +++ b/.github/workflows/tests_linters.yml @@ -6,6 +6,7 @@ jobs: tests-and-linters: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" runs-on: "${{ matrix.os }}" + timeout-minutes: 10 strategy: matrix: @@ -19,7 +20,6 @@ jobs: - name: Checkout jumanji ๐Ÿ uses: actions/checkout@v4 - - name: Install uv uses: astral-sh/setup-uv@v3 with: diff --git a/setup.py b/setup.py deleted file mode 100644 index d6f0c038c..000000000 --- a/setup.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import List - -import setuptools -from setuptools import setup - - -def _parse_requirements(path: str) -> List[str]: - """Returns content of given requirements file.""" - with open(os.path.join(path)) as f: - return [ - line.rstrip() for line in f if not (line.isspace() or line.startswith("#")) - ] - - -def _get_version() -> str: - """Grabs the package version from jumanji/version.py.""" - dict_ = {} - with open("jumanji/version.py") as f: - exec(f.read(), dict_) - return dict_["__version__"] - - -setup( - name="jumanji", - version=_get_version(), - author="InstaDeep", - author_email="clement.bonnet16@gmail.com", - description="A diverse suite of scalable reinforcement learning environments in JAX", - license="Apache 2.0", - url="https://github.com/instadeepai/jumanji/", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - keywords="reinforcement-learning python jax", - packages=setuptools.find_packages(), - python_requires=">=3.8", - install_requires=_parse_requirements("requirements/requirements.txt"), - extras_require={ - "dev": _parse_requirements("requirements/requirements-dev.txt"), - "train": _parse_requirements("requirements/requirements-train.txt"), - }, - package_data={"jumanji": ["py.typed"]}, - classifiers=[ - "Development Status :: 4 - Beta", - "Environment :: Console", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - "License :: OSI Approved :: Apache Software License", - ], - zip_safe=False, - include_package_data=True, -) From be2f440deef849a5de1c09619c8e135a8be26ee6 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 29 Oct 2024 16:47:02 +0200 Subject: [PATCH 3/8] chore: remove dev requirements --- requirements/requirements-dev.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index e2a0aadd2..ed227c357 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -22,9 +22,5 @@ pytest-cov pytest-mock pytest-parallel pytest-xdist -pytype scipy>=1.7.3 testfixtures -types-Pillow -types-requests<1.27 -types-setuptools From 308e7cf2566aa514522bf6eafbe98b0e76b55950 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 29 Oct 2024 17:09:43 +0200 Subject: [PATCH 4/8] fix: mypy python version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d335fb5ad..b14b2795e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ optional-dependencies.train={file=["requirements/requirements-train.txt"]} "Documentation"="https://instadeepai.github.io/jumanji" [tool.mypy] -python_version = 3.10 +python_version = "3.10" namespace_packages = true incremental = false cache_dir = "" From f64eea88b676630e1daf438784876bf5d6af8086 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 1 Nov 2024 13:16:43 +0200 Subject: [PATCH 5/8] chore: flake ignore A005 --- pyproject.toml | 5 +++-- setup.cfg | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b14b2795e..636dcae26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,3 @@ -[tool.isort] -profile = "black" [build-system] requires=["setuptools>=62.6"] build-backend="setuptools.build_meta" @@ -92,3 +90,6 @@ module = [ "PIL.*", ] ignore_missing_imports = true + +[tool.isort] +profile = "black" diff --git a/setup.cfg b/setup.cfg index d032b15c6..6d09b4167 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,7 @@ per-file-ignores = __init__.py:F401 ignore = A002 # Argument shadowing a Python builtin. A003 # Class attribute shadowing a Python builtin. + A005 # Module shadowing a Python builtin. D107 # Do not require docstrings for __init__. E266 # Do not require block comments to only have a single leading #. E731 # Do not assign a lambda expression, use a def. From 333bb4a5215ad6a4cb7d5a61debd9a4d13873a28 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 1 Nov 2024 14:21:17 +0200 Subject: [PATCH 6/8] chore: unpin flake8 --- requirements/requirements-dev.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index ed227c357..03d70be87 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,7 +1,6 @@ black==22.3.0 coverage -flake8==3.9.2 -importlib-metadata<5.0 +flake8 isort==5.11.5 livereload mkdocs==1.2.3 From c4bacb70c08f07c978c70e85cf74a03731efdbb6 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 1 Nov 2024 14:59:21 +0200 Subject: [PATCH 7/8] chore: upgrade flake and fix errors --- .pre-commit-config.yaml | 2 +- jumanji/env.py | 2 +- .../environments/logic/game_2048/viewer.py | 2 +- .../environments/logic/graph_coloring/env.py | 2 +- .../logic/sliding_tile_puzzle/viewer.py | 4 +-- jumanji/environments/packing/bin_pack/env.py | 6 ++--- .../packing/flat_pack/generator.py | 4 +-- .../environments/routing/mmst/generator.py | 2 +- .../environments/routing/multi_cvrp/viewer.py | 8 +++--- .../routing/robot_warehouse/conftest.py | 4 +-- .../routing/robot_warehouse/env.py | 4 ++- jumanji/specs.py | 8 +++--- jumanji/specs_test.py | 14 +++++----- .../networks/graph_coloring/actor_critic.py | 2 +- .../training/networks/mmst/actor_critic.py | 2 +- jumanji/training/networks/tsp/actor_critic.py | 4 +-- setup.cfg | 27 ++++++++++++------- 17 files changed, 54 insertions(+), 43 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 928835b3d..f62452ace 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: name: "Trailing whitespace fixer" - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 7.1.1 hooks: - id: flake8 name: "Linter" diff --git a/jumanji/env.py b/jumanji/env.py index 48035a992..9674960c8 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array: @cached_property def discount_spec(self) -> specs.BoundedArray: - """Returns the discount spec. By default, this is assumed to be a single float between 0 and 1. + """Returns the discount spec. By default, this is assumed to be a float between 0 and 1. Returns: discount_spec: a `specs.BoundedArray` spec. diff --git a/jumanji/environments/logic/game_2048/viewer.py b/jumanji/environments/logic/game_2048/viewer.py index 819d1c251..b64b48b20 100644 --- a/jumanji/environments/logic/game_2048/viewer.py +++ b/jumanji/environments/logic/game_2048/viewer.py @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board. + """This function returns a `Matplotlib` figure and axes for displaying the 2048 game board. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index b5e65a3e5..32de81019 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -296,7 +296,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> animation.FuncAnimation: - """Creates an animated gif of the `GraphColoring` environment based on the sequence of game states. + """Creates an animated gif of the `GraphColoring` environment based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. diff --git a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py index 6596a323d..7fa905fcf 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py @@ -71,7 +71,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states. + """Creates an animated gif of the sliding tiles puzzle game based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the game puzzle. + """This function returns a `Matplotlib` figure and axes for displaying the puzzle. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 5b2b2c7cf..3506fa0b0 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -397,9 +397,9 @@ def close(self) -> None: def _make_observation_and_extras( self, state: State ) -> Tuple[State, Observation, Dict]: - """Computes the observation and the environment metrics to include in `timestep.extras`. Also - updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained - by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. + """Computes the observation and the environment metrics to include in `timestep.extras`. + Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is + obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. """ obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems( state.ems, state.ems_mask diff --git a/jumanji/environments/packing/flat_pack/generator.py b/jumanji/environments/packing/flat_pack/generator.py index 7ea8495d5..412c4d9e1 100644 --- a/jumanji/environments/packing/flat_pack/generator.py +++ b/jumanji/environments/packing/flat_pack/generator.py @@ -28,8 +28,8 @@ class InstanceGenerator(abc.ABC): - """Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible - for generating a problem instance when the environment is reset. + """Base class for generators for the flat_pack environment. An `InstanceGenerator` is + responsible for generating a problem instance when the environment is reset. """ def __init__( diff --git a/jumanji/environments/routing/mmst/generator.py b/jumanji/environments/routing/mmst/generator.py index c71d01054..be0e4685a 100644 --- a/jumanji/environments/routing/mmst/generator.py +++ b/jumanji/environments/routing/mmst/generator.py @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State: class SplitRandomGenerator(Generator): - """Generates a random environments that is solvable by spliting the graph into multiple sub graphs. + """Generates a random environments that is solvable by spliting the graph into sub graphs. Returns a graph and with a desired number of edges and nodes to connect per agent. """ diff --git a/jumanji/environments/routing/multi_cvrp/viewer.py b/jumanji/environments/routing/multi_cvrp/viewer.py index 3dbda1aec..dd6fc5e8e 100644 --- a/jumanji/environments/routing/multi_cvrp/viewer.py +++ b/jumanji/environments/routing/multi_cvrp/viewer.py @@ -205,10 +205,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None: ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id)) def _add_tour(self, ax: plt.Axes, state: State) -> None: - """Add the customers and the depot to the plot, and draw each route in the tour in a different - colour. The tour is the entire trajectory between the visited customers and a route is a - trajectory either starting and ending at the depot or starting at the depot and ending at - the current city.""" + """Add the customers and the depot to the plot, and draw each route in the tour in a + different colour. The tour is the entire trajectory between the visited customers and a + route is a trajectory either starting and ending at the depot or starting at the depot + and ending at the current city.""" x_coords, y_coords = ( state.nodes.coordinates[:, 0] / self._map_max, state.nodes.coordinates[:, 1] / self._map_max, diff --git a/jumanji/environments/routing/robot_warehouse/conftest.py b/jumanji/environments/routing/robot_warehouse/conftest.py index 95ed58271..68d815705 100644 --- a/jumanji/environments/routing/robot_warehouse/conftest.py +++ b/jumanji/environments/routing/robot_warehouse/conftest.py @@ -31,8 +31,8 @@ @pytest.fixture(scope="module") def robot_warehouse_env() -> RobotWarehouse: - """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns, - a column height of 2, sensor range of 1 and a request queue size of 4.""" + """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf + columns, a column height of 2, sensor range of 1 and a request queue size of 4.""" generator = RandomGenerator( shelf_rows=1, shelf_columns=3, diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index eb9c2c578..8ab107bc4 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]: @cached_property def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + """Returns the action spec. 5 actions: + [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,). """ diff --git a/jumanji/specs.py b/jumanji/specs.py index 6dc40237b..6cfacd546 100644 --- a/jumanji/specs.py +++ b/jumanji/specs.py @@ -44,9 +44,9 @@ class Spec(abc.ABC, Generic[T]): - """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested - specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the - `dm_env` object.""" + """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for + nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from + the `dm_env` object.""" def __init__( self, @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec": class Array(Spec[chex.Array]): - """Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments. + """Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments. An `Array` spec allows an API to describe the arrays that it accepts or returns, before that array exists. diff --git a/jumanji/specs_test.py b/jumanji/specs_test.py index 74e95b512..09b9f48b1 100644 --- a/jumanji/specs_test.py +++ b/jumanji/specs_test.py @@ -589,7 +589,7 @@ def test_array(self) -> None: converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None: converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None: converted_spec: dm_env.specs.DiscreteArray = ( specs.jumanji_specs_to_dm_env_specs(jumanji_spec) ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -675,7 +675,7 @@ def test_array(self) -> None: jumanji_spec = specs.Array((1, 2), jnp.int32) gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert_trees_all_equal(converted_spec.low, gym_space.low) assert_trees_all_equal(converted_spec.high, gym_space.high) assert converted_spec.shape == gym_space.shape @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None: ) gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert_trees_all_equal(converted_spec.low, gym_space.low) @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None: jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32) gym_space = gym.spaces.Discrete(n=5) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert converted_spec.n == gym_space.n @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None: ) gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6]) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert jnp.array_equal(converted_spec.nvec, gym_space.nvec) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 6e2e336f6..62187d709 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_color_node_block_{block_id+1}", + name=f"cross_attention_color_node_block_{block_id + 1}", )(color_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index 45e776b4c..18b21dec3 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_agent_node_block_{block_id+1}", + name=f"cross_attention_agent_node_block_{block_id + 1}", )(agents_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index cff891c5e..1d720743b 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -72,8 +72,8 @@ def __init__( transformer_mlp_units: Sequence[int], name: Optional[str] = None, ): - """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self - attention. + """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks + of self attention. """ super().__init__(name=name) self.transformer_num_blocks = transformer_num_blocks diff --git a/setup.cfg b/setup.cfg index 6d09b4167..47a19d216 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,12 +20,21 @@ docstring-convention = google per-file-ignores = __init__.py:F401 ignore = - A002 # Argument shadowing a Python builtin. - A003 # Class attribute shadowing a Python builtin. - A005 # Module shadowing a Python builtin. - D107 # Do not require docstrings for __init__. - E266 # Do not require block comments to only have a single leading #. - E731 # Do not assign a lambda expression, use a def. - W503 # Line break before binary operator (not compatible with black). - B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil. - E203 # black and flake8 disagree on whitespace before ':'. +# Argument shadowing a Python builtin. + A002 +# Class attribute shadowing a Python builtin. + A003 +# Module shadowing a Python builtin. + A005 +# Do not require docstrings for __init__. + D107 +# Do not require block comments to only have a single leading #. + E266 +# Do not assign a lambda expression, use a def. + E731 +# Line break before binary operator (not compatible with black). + W503 +# assertRaises(Exception): or pytest.raises(Exception) should be considered evil. + B017 +# black and flake8 disagree on whitespace before ':'. + E203 From 1ecabdfdc1066b1dd79b9b397622e6cbd66ce1a4 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 1 Nov 2024 15:49:48 +0200 Subject: [PATCH 8/8] chore: longer timeout for tests --- .github/workflows/tests_linters.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests_linters.yml b/.github/workflows/tests_linters.yml index 5110c93f1..fe6b148cb 100644 --- a/.github/workflows/tests_linters.yml +++ b/.github/workflows/tests_linters.yml @@ -6,7 +6,7 @@ jobs: tests-and-linters: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" runs-on: "${{ matrix.os }}" - timeout-minutes: 10 + timeout-minutes: 20 strategy: matrix: