Skip to content

Commit

Permalink
Simplified gym wrapper, and added a second version of each environmen…
Browse files Browse the repository at this point in the history
…t which uses minimal_action_set.
  • Loading branch information
kenjyoung committed Oct 8, 2021
1 parent 4f81cb8 commit c645454
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 228 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ MinAtar now includes an optional OpenAI Gym wrapper. To include this wrapper in
```bash
INSTALL_GYM_WRAPPER=1 pip install .
```
This will additionally install a package called gym_minatar, which when imported will register the following gym environments corresponding to the associated MinAtar game: Asterix-MinAtar-v0, Breakout-MinAtar-v0, Freeway-MinAtar-v0, Seaquest-MinAtar-v0, SpaceInvaders-MinAtar-v0.
This will additionally install a package called gym_minatar, which when imported will register the following gym environments corresponding to the associated MinAtar game: Asterix-MinAtar-v0, Breakout-MinAtar-v0, Freeway-MinAtar-v0, Seaquest-MinAtar-v0, SpaceInvaders-MinAtar-v0. An additional version of each game is also registered, which uses the minimal action set for that game (as opposed to all 6 actions, some of which are equivalent to no-op depending on the game); to use this version, simply replace v0 above with v1. Note that the results included in this repo and the associated paper use the full action set of 6 actions.

## Visualizing the Environments
We provide 2 ways to visualize a MinAtar environment.
Expand Down
12 changes: 8 additions & 4 deletions gym_minatar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from gym.envs.registration import register

# Register two Gym ids for every MinAtar game:
#   <Game>-MinAtar-v0 -> <Game>Env              (full action set)
#   <Game>-MinAtar-v1 -> MinimalAction<Game>Env (minimal action set)
# Each id is registered exactly once; a repeated v0 registration is redundant.
for game in ['Asterix', 'Breakout', 'Freeway', 'Seaquest', 'SpaceInvaders']:
    register(
        id=f'{game}-MinAtar-v0',
        entry_point=f'gym_minatar.envs:{game}Env',
    )
    register(
        id=f'{game}-MinAtar-v1',
        entry_point=f'gym_minatar.envs:MinimalAction{game}Env',
    )
6 changes: 1 addition & 5 deletions gym_minatar/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
from gym_minatar.envs.base import BaseEnv
from gym_minatar.envs.asterix import AsterixEnv
from gym_minatar.envs.breakout import BreakoutEnv
from gym_minatar.envs.freeway import FreewayEnv
from gym_minatar.envs.seaquest import SeaquestEnv
from gym_minatar.envs.space_invaders import SpaceInvadersEnv
from gym_minatar.envs.game_envs import *
36 changes: 0 additions & 36 deletions gym_minatar/envs/asterix.py

This file was deleted.

79 changes: 41 additions & 38 deletions gym_minatar/envs/base.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
#Adapted from https://github.com/qlan3/gym-games
import os
import importlib
import numpy as np
import gym
from gym import spaces

from minatar import Environment


class BaseEnv(gym.Env):
    """Gym wrapper around a MinAtar ``Environment``.

    Subclasses are expected to set ``self.game_name`` and
    ``self.display_time`` and then call :meth:`init`, which builds the
    underlying game and the Gym action/observation spaces.
    """

    metadata = {'render.modes': ['human', 'array']}

    def __init__(self, display_time=50, **kwargs):
        """Store the display setting and build the game via :meth:`init`.

        Args:
            display_time: Value handed to ``Environment.display_state`` when
                rendering in ``'human'`` mode.
            **kwargs: Forwarded unchanged to :meth:`init`.
        """
        # Placeholder name; concrete subclasses assign a real MinAtar game.
        self.game_name = 'Game Name'
        self.display_time = display_time
        self.init(**kwargs)

    def init(self, use_minimal_action_set=False, **kwargs):
        """Create the MinAtar game and derive the Gym spaces from it.

        Args:
            use_minimal_action_set: When true, the action space only covers
                the entries of ``Environment.minimal_action_set()`` and
                :meth:`step` translates indices through that list.
            **kwargs: Extra keyword arguments for ``minatar.Environment``.
        """
        # Remember the construction options so seed() can rebuild the game
        # with the same configuration instead of falling back to defaults.
        self._game_kwargs = dict(kwargs)
        self.game = Environment(env_name=self.game_name, **kwargs)
        self.use_minimal_action_set = use_minimal_action_set
        self.action_set = self.game.env.action_map
        if self.use_minimal_action_set:
            num_actions = len(self.game.minimal_action_set())
        else:
            num_actions = self.game.num_actions()
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = spaces.Box(0.0, 1.0, shape=self.game.state_shape(), dtype=bool)

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        With the minimal action set enabled, ``action`` is an index into
        ``Environment.minimal_action_set()`` and is mapped to the game's own
        action before being applied. ``info`` is always an empty dict.
        """
        if self.use_minimal_action_set:
            action = self.game.minimal_action_set()[action]
        reward, done = self.game.act(action)
        return (self.game.state(), reward, done, {})

    def reset(self):
        """Reset the game and return its state."""
        self.game.reset()
        return self.game.state()

    def seed(self, seed=None):
        """Rebuild the game with ``random_seed=seed`` and return ``seed``.

        The keyword options captured by :meth:`init` are re-applied so that
        seeding does not silently discard the original configuration. The
        explicit ``seed`` argument always overrides any stored
        ``random_seed``.
        """
        # getattr keeps this working for subclasses that override init()
        # without recording the construction options.
        env_kwargs = dict(getattr(self, '_game_kwargs', {}))
        env_kwargs['random_seed'] = seed
        self.game = Environment(env_name=self.game_name, **env_kwargs)
        return seed

    def render(self, mode='human'):
        """Render the game.

        ``'array'`` returns the current state; ``'human'`` calls
        ``Environment.display_state`` with ``self.display_time``. Any other
        mode does nothing and returns ``None``.
        """
        if mode == 'array':
            return self.game.state()
        elif mode == 'human':
            self.game.display_state(self.display_time)

    def close(self):
        """Close the display if the game reports one is active; return 0."""
        if self.game.visualized:
            self.game.close_display()
        return 0
36 changes: 0 additions & 36 deletions gym_minatar/envs/breakout.py

This file was deleted.

36 changes: 0 additions & 36 deletions gym_minatar/envs/freeway.py

This file was deleted.

62 changes: 62 additions & 0 deletions gym_minatar/envs/game_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#Adapted from https://github.com/qlan3/gym-games
from gym_minatar.envs.base import BaseEnv

class AsterixEnv(BaseEnv):
    """Gym environment for the MinAtar 'asterix' game (default action set)."""

    def __init__(self, display_time=50, **kwargs):
        # Set the attributes the base class reads, then let BaseEnv.init
        # build the game; remaining kwargs are passed through untouched.
        self.display_time = display_time
        self.game_name = 'asterix'
        self.init(**kwargs)

class MinimalActionAsterixEnv(BaseEnv):
    """Gym environment for the MinAtar 'asterix' game using its minimal action set."""

    def __init__(self, display_time=50, **kwargs):
        # Same setup as the full-action variant, but the minimal action set
        # is always requested from BaseEnv.init.
        self.display_time = display_time
        self.game_name = 'asterix'
        self.init(use_minimal_action_set=True, **kwargs)

class BreakoutEnv(BaseEnv):
    """Gym environment for the MinAtar 'breakout' game (default action set)."""

    def __init__(self, display_time=50, **kwargs):
        # Set the attributes the base class reads, then let BaseEnv.init
        # build the game; remaining kwargs are passed through untouched.
        self.display_time = display_time
        self.game_name = 'breakout'
        self.init(**kwargs)

class MinimalActionBreakoutEnv(BaseEnv):
    """Gym environment for the MinAtar 'breakout' game using its minimal action set."""

    def __init__(self, display_time=50, **kwargs):
        # Same setup as the full-action variant, but the minimal action set
        # is always requested from BaseEnv.init.
        self.display_time = display_time
        self.game_name = 'breakout'
        self.init(use_minimal_action_set=True, **kwargs)

class FreewayEnv(BaseEnv):
    """Gym environment for the MinAtar 'freeway' game (default action set)."""

    def __init__(self, display_time=50, **kwargs):
        # Set the attributes the base class reads, then let BaseEnv.init
        # build the game; remaining kwargs are passed through untouched.
        self.display_time = display_time
        self.game_name = 'freeway'
        self.init(**kwargs)

class MinimalActionFreewayEnv(BaseEnv):
    """Gym environment for the MinAtar 'freeway' game using its minimal action set."""

    def __init__(self, display_time=50, **kwargs):
        # Same setup as the full-action variant, but the minimal action set
        # is always requested from BaseEnv.init.
        self.display_time = display_time
        self.game_name = 'freeway'
        self.init(use_minimal_action_set=True, **kwargs)

class SeaquestEnv(BaseEnv):
    """Gym environment for the MinAtar 'seaquest' game (default action set)."""

    def __init__(self, display_time=50, **kwargs):
        # Set the attributes the base class reads, then let BaseEnv.init
        # build the game; remaining kwargs are passed through untouched.
        self.display_time = display_time
        self.game_name = 'seaquest'
        self.init(**kwargs)

class MinimalActionSeaquestEnv(BaseEnv):
    """Gym environment for the MinAtar 'seaquest' game using its minimal action set."""

    def __init__(self, display_time=50, **kwargs):
        # Same setup as the full-action variant, but the minimal action set
        # is always requested from BaseEnv.init.
        self.display_time = display_time
        self.game_name = 'seaquest'
        self.init(use_minimal_action_set=True, **kwargs)

class SpaceInvadersEnv(BaseEnv):
    """Gym environment for the MinAtar 'space_invaders' game (default action set)."""

    def __init__(self, display_time=50, **kwargs):
        # Set the attributes the base class reads, then let BaseEnv.init
        # build the game; remaining kwargs are passed through untouched.
        self.display_time = display_time
        self.game_name = 'space_invaders'
        self.init(**kwargs)

class MinimalActionSpaceInvadersEnv(BaseEnv):
    """Gym environment for the MinAtar 'space_invaders' game using its minimal action set."""

    def __init__(self, display_time=50, **kwargs):
        # Same setup as the full-action variant, but the minimal action set
        # is always requested from BaseEnv.init.
        self.display_time = display_time
        self.game_name = 'space_invaders'
        self.init(use_minimal_action_set=True, **kwargs)
36 changes: 0 additions & 36 deletions gym_minatar/envs/seaquest.py

This file was deleted.

36 changes: 0 additions & 36 deletions gym_minatar/envs/space_invaders.py

This file was deleted.

0 comments on commit c645454

Please sign in to comment.