Skip to content

Commit

Permalink
RGB render mode (#29)
Browse files Browse the repository at this point in the history
* improved readability

* rgb render mode for pixels obs

* Update gym.py

* Update gym.py

* Update gym.py

gym\utils\passive_env_checker already checks that

* Update gym.py

when in human mode, env is always rendered as in gym

* Update README.md

* Minor changes to README

---------

Co-authored-by: Kenny Young <[email protected]>
  • Loading branch information
sparisi and kenjyoung authored May 5, 2023
1 parent 429898c commit 2d73682
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 23 deletions.
29 changes: 21 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pip install minatar
```
If you'd like to install MinAtar from the github repo instead (for example if you'd like to modify the code), please follow the steps below:

1. Clone the repo:
1. Clone the repo:
```bash
git clone https://github.com/kenjyoung/MinAtar.git
```
Expand Down Expand Up @@ -84,29 +84,42 @@ pip install minatar==1.0.11

## Visualizing the Environments
We provide 2 ways to visualize a MinAtar environment.

### Using Environment.display_state()
The Environment class includes a simple visualizer using matplotlib in the display_state function. To use this simply call:
```python
env = Environment('breakout')
env.display_state(50)
```
where env is an instance of MinAtar.Environment. The argument is the number of milliseconds to display the state before continuing execution. To close the resulting display window call:
or, if you're using the gym interface:
```python
env = gym.make('Minatar/Breakout-v1')
env.game.display_state(50)
```
The argument is the number of milliseconds to display the state before continuing execution. To close the resulting display window call:
```python
env.game.close_display()
```
This is the simplest way to visualize the environments, unless you need to handle user input during execution in which case you could use the provided GUI class.
With the gym interface, you can also enable real-time rendering by making the environment in human render mode. In this case, the display_state function will be called automatically at every step:
```python
env.close_display()
env = gym.make('MinAtar/Breakout-v1', render_mode='human')
env.reset()
env.step(1)
```
This is the simplest way to visualize the environments, unless you need to handle user input during execution in which case you could use the provided GUI class.
### Using GUI class
We also include a slightly more complex GUI to visualize the environments and optionally handle user input. This GUI is used in examples/human_play.py to play as a human and examples/agent_play.py to visualize the performance of trained agents. To use the GUI you can import it in your code with:
```python
from minatar.gui import GUI
```
Initialize an instance of the GUI class by providing a name for the window, and the integer number of input channels for the minatar environment to be visualized. For example:
Initialize an instance of the GUI class by providing a name for the window, and the integer number of input channels for the MinAtar environment to be visualized. For example:
```python
GUI(env.game_name(), env.n_channels)
```
where env is an instance of minatar.Environment. The recommended way to use the GUI for visualizing an environment is to include all you're agent-environment interaction code in a function that looks something like this:
where env is an instance of minatar.Environment. The recommended way to use the GUI for visualizing an environment is to include all your agent-environment interaction code in a function that looks something like this:
```python
def func():
gui.display_state(env.state())
Expand All @@ -128,7 +141,7 @@ This will enter the agent environment interaction loop and then run the GUI thre
- [JAX](https://github.com/RobertTLange/gymnax)
## Results
The following plots display results for DQN (Mnih et al., 2015) and actor-critic (AC) with eligibility traces. Our DQN agent uses a significantly smaller network compared to that of Mnih et al., 2015. We display results for DQN with and without experience reply. Our AC agent uses a similar architecture to DQN, but does not use experience replay. We display results for two values of the trace decay parameter, 0.8 and 0.0. Each curve is the average of 30 independent runs with different random seeds. The top plots display the sensitivity of final performance to the step-size parameter, while the bottom plots display the average return during training as a function of training frames. For further information, see the paper on MinAtar available [here](https://arxiv.org/abs/1903.03176).
The following plots display results for DQN (Mnih et al., 2015) and actor-critic (AC) with eligibility traces. Our DQN agent uses a significantly smaller network compared to that of Mnih et al., 2015. We display results for DQN with and without experience reply. Our AC agent uses a similar architecture to DQN, but does not use experience replay. We display results for two values of the trace decay parameter, 0.8 and 0.0. Each curve is the average of 30 independent runs with different random seeds. The top plots display the sensitivity of final performance to the step-size parameter, while the bottom plots display the average return during training as a function of training frames. For further information, see the paper on MinAtar available [here](https://arxiv.org/abs/1903.03176).
**Note, the currently displayed results for Seaquest are for the version in MinAtar v1.0.10 and lower, where a bug caused the oxygen bar to flash to full one step before running out**. Results for the updated version may be different.
Expand Down Expand Up @@ -199,4 +212,4 @@ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
along with this program. If not, see <http://www.gnu.org/licenses/>.
24 changes: 14 additions & 10 deletions minatar/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
# Environment
#
# Wrapper for all the specific game environments. Imports the environment specified by the user and then acts as a
# minimal interface. Also defines code for displaying the environment for a human user.
# minimal interface. Also defines code for displaying the environment for a human user.
#
#####################################################################################################################
class Environment:
def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True, random_seed = None):
env_module = import_module('minatar.environments.'+env_name)
def __init__(self, env_name, sticky_action_prob=0.1,
difficulty_ramping=True, random_seed=None):
env_module = import_module('minatar.environments.' + env_name)
self.random = np.random.RandomState(random_seed)
self.env_name = env_name
self.env = env_module.Env(ramping = difficulty_ramping, random_state = self.random)
self.env = env_module.Env(
ramping=difficulty_ramping, random_state=self.random)
self.n_channels = self.env.state_shape()[2]
self.sticky_action_prob = sticky_action_prob
self.last_action = 0
Expand All @@ -28,7 +30,7 @@ def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True

# Wrapper for env.act
def act(self, a):
if(self.random.rand()<self.sticky_action_prob):
if(self.random.rand() < self.sticky_action_prob):
a = self.last_action
self.last_action = a
return self.env.act(a)
Expand Down Expand Up @@ -69,8 +71,8 @@ def display_state(self, time=50):
colors = mpl.colors
sns = __import__('seaborn', globals(), locals())
self.cmap = sns.color_palette("cubehelix", self.n_channels)
self.cmap.insert(0,(0,0,0))
self.cmap=colors.ListedColormap(self.cmap)
self.cmap.insert(0, (0,0,0))
self.cmap = colors.ListedColormap(self.cmap)
bounds = [i for i in range(self.n_channels+2)]
self.norm = colors.BoundaryNorm(bounds, self.n_channels+1)
_, self.ax = plt.subplots(1,1)
Expand All @@ -81,9 +83,11 @@ def display_state(self, time=50):
plt.show(block=False)
self.closed = False
state = self.env.state()
numerical_state = np.amax(state*np.reshape(np.arange(self.n_channels)+1,(1,1,-1)),2)+0.5
self.ax.imshow(numerical_state, cmap=self.cmap, norm=self.norm, interpolation='none')
plt.pause(time/1000)
numerical_state = np.amax(
state * np.reshape(np.arange(self.n_channels) + 1, (1,1,-1)), 2) + 0.5
self.ax.imshow(
numerical_state, cmap=self.cmap, norm=self.norm, interpolation='none')
plt.pause(time / 1000)
plt.cla()

def close_display(self):
Expand Down
34 changes: 29 additions & 5 deletions minatar/gym.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Adapted from https://github.com/qlan3/gym-games
import numpy as np
import gym
from gym import spaces
from gym.envs import register
Expand All @@ -7,11 +8,13 @@


class BaseEnv(gym.Env):
metadata = {"render.modes": ["human", "array"]}
metadata = {"render_modes": ["human", "array", "rgb_array"]}

def __init__(self, game, display_time=50, use_minimal_action_set=False, **kwargs):
def __init__(self, game, display_time=50, use_minimal_action_set=False,
render_mode=None, **kwargs):
self.game_name = game
self.display_time = display_time
self.render_mode = render_mode

self.game_kwargs = kwargs
self.seed()
Expand All @@ -29,6 +32,8 @@ def __init__(self, game, display_time=50, use_minimal_action_set=False, **kwargs
def step(self, action):
action = self.action_set[action]
reward, done = self.game.act(action)
if self.render_mode == "human":
self.render()
return self.game.state(), reward, done, False, {}

def reset(self, seed=None, options=None):
Expand All @@ -39,6 +44,8 @@ def reset(self, seed=None, options=None):
**self.game_kwargs
)
self.game.reset()
if self.render_mode == "human":
self.render()
return self.game.state(), {}

def seed(self, seed=None):
Expand All @@ -49,11 +56,28 @@ def seed(self, seed=None):
)
return seed

def render(self, mode="human"):
if mode == "array":
def render(self):
if self.render_mode is None:
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. gym("{self.spec.id}", render_mode="rgb_array")'
)
return
if self.render_mode == "array":
return self.game.state()
elif mode == "human":
elif self.render_mode == "human":
self.game.display_state(self.display_time)
elif self.render_mode == "rgb_array": # use the same color palette of Environment.display_state
state = self.game.state()
n_channels = state.shape[-1]
sns = __import__('seaborn', globals(), locals())
cmap = sns.color_palette("cubehelix", n_channels)
cmap.insert(0, (0,0,0))
numerical_state = np.amax(
state * np.reshape(np.arange(n_channels) + 1, (1,1,-1)), 2)
rgb_array = np.stack(cmap)[numerical_state]
return rgb_array

def close(self):
if self.game.visualized:
Expand Down

0 comments on commit 2d73682

Please sign in to comment.