Skip to content

Commit

Permalink
Fast random (#31)
Browse files Browse the repository at this point in the history
* improved readability

* rgb render mode for pixels obs

* Update gym.py

* Update gym.py

* Update gym.py

gym\utils\passive_env_checker already checks that

* Update gym.py

when in human mode, env is always rendered as in gym

* Update README.md

* replaced choice with better rand functions

* better seeding

seed and reset do not remake the env from scratch, they just set the random_state

the whole code is more gym-like

* Update setup.py

* Update environment.py

* Update gym.py

* try import guard

* Update README.md

* Update gym.py

* Minor changes to README

* readme

* Update README.md

* fixed my email domain in setup.py

---------

Co-authored-by: Kenny Young <[email protected]>
  • Loading branch information
sparisi and kenjyoung authored May 8, 2023
1 parent f2b7152 commit 42741e5
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 115 deletions.
35 changes: 20 additions & 15 deletions minatar/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from importlib import import_module
import numpy as np

# Rendering dependencies (matplotlib/seaborn) are optional: the environment
# can still be constructed and stepped without them; only state display needs
# them.
try:
    from matplotlib import pyplot as plt
    from matplotlib import colors
    import seaborn as sns
except ImportError:
    # Catch only ImportError: a bare `except` would also swallow
    # KeyboardInterrupt/SystemExit and hide unrelated failures.
    import logging
    logging.warning("Cannot import matplotlib and/or seaborn. "
                    "Will not be able to render the environment.")

#####################################################################################################################
# Environment
Expand All @@ -16,21 +24,26 @@
#####################################################################################################################
class Environment:
def __init__(self, env_name, sticky_action_prob=0.1,
             difficulty_ramping=True):
    """Wrap a MinAtar game module as a stateful environment.

    Args:
        env_name: name of a game module under ``minatar.environments``.
        sticky_action_prob: probability that the previous action is
            repeated instead of the requested one.
        difficulty_ramping: whether the game's difficulty ramps over time.

    NOTE: the old ``random_seed`` constructor argument was removed;
    call ``seed(random_seed)`` after construction instead.
    """
    env_module = import_module('minatar.environments.' + env_name)
    # Unseeded by default; seed() replaces this state for both the
    # wrapper (sticky actions) and the underlying game.
    self.random = np.random.RandomState()
    self.env_name = env_name
    self.env = env_module.Env(ramping=difficulty_ramping)
    self.n_channels = self.env.state_shape()[2]
    self.sticky_action_prob = sticky_action_prob
    self.last_action = 0
    self.visualized = False  # display figure is created lazily
    self.closed = False

# Seeding numpy random for reproducibility
def seed(self, seed=None):
    """Re-seed the shared numpy RandomState; a ``None`` seed is a no-op."""
    if seed is None:
        return
    # A single RandomState object is shared between the wrapper and the
    # underlying game, so one seed controls all randomness.
    fresh_state = np.random.RandomState(seed)
    self.random = fresh_state
    self.env.random = fresh_state

# Wrapper for env.act
def act(self, a):
    """Apply action ``a`` with sticky actions, then step the game.

    With probability ``sticky_action_prob`` the previous action is
    repeated instead of ``a``. Returns the underlying game's
    (reward, terminal) pair.
    """
    # Diff residue removed: the conditional appeared twice (old and new
    # diff line); exactly one check must run per step.
    if self.random.rand() < self.sticky_action_prob:
        a = self.last_action
    self.last_action = a
    return self.env.act(a)
Expand Down Expand Up @@ -61,15 +74,7 @@ def minimal_action_set(self):

# Display the current environment state for time milliseconds using matplotlib
def display_state(self, time=50):
if(not self.visualized):
global plt
global colors
global sns
mpl = __import__('matplotlib.pyplot', globals(), locals())
plt = mpl.pyplot
mpl = __import__('matplotlib.colors', globals(), locals())
colors = mpl.colors
sns = __import__('seaborn', globals(), locals())
if not self.visualized:
self.cmap = sns.color_palette("cubehelix", self.n_channels)
self.cmap.insert(0, (0,0,0))
self.cmap = colors.ListedColormap(self.cmap)
Expand All @@ -78,7 +83,7 @@ def display_state(self, time=50):
_, self.ax = plt.subplots(1,1)
plt.show(block=False)
self.visualized = True
if(self.closed):
if self.closed:
_, self.ax = plt.subplots(1,1)
plt.show(block=False)
self.closed = False
Expand Down
23 changes: 10 additions & 13 deletions minatar/environments/asterix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


#####################################################################################################################
# Env
# Env
#
# The player can move freely along the 4 cardinal directions. Enemies and treasure spawn from the sides. A reward of
# +1 is given for picking up treasure. Termination occurs if the player makes contact with an enemy. Enemy and
Expand All @@ -26,7 +26,7 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = True, random_state = None):
def __init__(self, ramping=True):
self.channels ={
'player':0,
'enemy':1,
Expand All @@ -35,18 +35,15 @@ def __init__(self, ramping = True, random_state = None):
}
self.action_map = ['n','l','u','r','d','f']
self.ramping = ramping
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.random = np.random.RandomState()
self.reset()

# Update environment according to agent action
def act(self, a):
r = 0
if(self.terminal):
return r, self.terminal

a = self.action_map[a]

# Spawn enemy if timer is up
Expand All @@ -65,9 +62,9 @@ def act(self, a):
self.player_y = min(8, self.player_y+1)

# Update entities
for i in range(len(self.entities)):
x = self.entities[i]
if(x is not None):
for i in range(len(self.entities)):
x = self.entities[i]
if(x is not None):
if(x[0:2]==[self.player_x,self.player_y]):
if(self.entities[i][3]):
self.entities[i] = None
Expand Down Expand Up @@ -109,13 +106,13 @@ def act(self, a):

# Spawn a new enemy or treasure at a random location with random direction (if all rows are filled do nothing)
def _spawn_entity(self):
    """Spawn an enemy or gold in a random free slot; no-op when all full.

    Entities are [x, y, moving_right, is_gold]: they enter from the left
    (x=0) or right (x=9) edge with equal probability and are gold with
    probability 1/3.
    """
    lr = self.random.rand() < 1 / 2
    is_gold = self.random.rand() < 1 / 3
    x = 0 if lr else 9
    # `is None` instead of `== None`; also removes the duplicated
    # old/new diff lines for lr/is_gold/slot.
    free_slots = [i for i, e in enumerate(self.entities) if e is None]
    if not free_slots:
        return
    slot = free_slots[self.random.randint(len(free_slots))]
    # Slot i maps to row i+1 (rows are offset by one from slot indices).
    self.entities[slot] = [x, slot + 1, lr, is_gold]

# Query the current level of the difficulty ramp, could be used as additional input to agent for example
Expand Down
19 changes: 8 additions & 11 deletions minatar/environments/breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,31 @@
#####################################################################################################################
# Env
#
# The player controls a paddle on the bottom of the screen and must bounce a ball to break 3 rows of bricks along the
# top of the screen. A reward of +1 is given for each brick broken by the ball. When all bricks are cleared another 3
# rows are added. The ball travels only along diagonals, when it hits the paddle it is bounced either to the left or
# The player controls a paddle on the bottom of the screen and must bounce a ball to break 3 rows of bricks along the
# top of the screen. A reward of +1 is given for each brick broken by the ball. When all bricks are cleared another 3
# rows are added. The ball travels only along diagonals, when it hits the paddle it is bounced either to the left or
# right depending on the side of the paddle hit, when it hits a wall or brick it is reflected. Termination occurs when
# the ball hits the bottom of the screen. The balls direction is indicated by a trail channel.
#
#####################################################################################################################
class Env:
def __init__(self, ramping=None):
    """Set up the breakout game state.

    Args:
        ramping: unused in breakout; accepted for interface compatibility
            with the games whose difficulty ramps.
    """
    self.channels = {
        'paddle': 0,
        'ball': 1,
        'trail': 2,
        'brick': 3,
    }
    self.action_map = ['n', 'l', 'u', 'r', 'd', 'f']
    # Unseeded by default; the wrapper's seed() replaces this attribute
    # (the old random_state constructor argument was removed).
    self.random = np.random.RandomState()
    self.reset()

# Update environment according to agent action
def act(self, a):
r = 0
if(self.terminal):
return r, self.terminal

a = self.action_map[a]

# Resolve player action
Expand Down Expand Up @@ -100,7 +97,7 @@ def act(self, a):

# Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None
def difficulty_ramp(self):
    """Breakout has no difficulty ramp; always returns None."""
    # Diff residue removed: the return statement appeared twice and the
    # second copy was unreachable.
    return None

# Process the game-state into the 10x10xn state provided to the agent and return
def state(self):
Expand All @@ -114,7 +111,7 @@ def state(self):
# Reset to start state for new episode
def reset(self):
self.ball_y = 3
ball_start = self.random.choice(2)
ball_start = self.random.randint(2)
self.ball_x, self.ball_dir = [(0,2),(9,3)][ball_start]
self.pos = 4
self.brick_map = np.zeros((10,10))
Expand Down
27 changes: 12 additions & 15 deletions minatar/environments/freeway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
#####################################################################################################################
# Env
#
# The player begins at the bottom of the screen and motion is restricted to traveling up and down. Player speed is
# also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches
# the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen
# and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of
# the screen. Car direction and speed is indicated by 5 trail channels, the location of the trail gives direction
# while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames).
# Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs
# The player begins at the bottom of the screen and motion is restricted to traveling up and down. Player speed is
# also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches
# the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen
# and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of
# the screen. Car direction and speed is indicated by 5 trail channels, the location of the trail gives direction
# while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames).
# Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs
# after 2500 frames have elapsed.
#
#####################################################################################################################
class Env:
def __init__(self, ramping = None, random_state = None):
def __init__(self, ramping=None):
self.channels ={
'chicken':0,
'car':1,
Expand All @@ -39,18 +39,15 @@ def __init__(self, ramping = None, random_state = None):
'speed5':6,
}
self.action_map = ['n','l','u','r','d','f']
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.random = np.random.RandomState()
self.reset()

# Update environment according to agent action
def act(self, a):
r = 0
if(self.terminal):
return r, self.terminal

a = self.action_map[a]

if(a=='u' and self.move_timer==0):
Expand Down Expand Up @@ -91,7 +88,7 @@ def act(self, a):

# Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None
def difficulty_ramp(self):
    """Freeway has no difficulty ramp; always returns None."""
    # Diff residue removed: the return statement appeared twice and the
    # second copy was unreachable.
    return None

# Process the game-state into the 10x10xn state provided to the agent and return
def state(self):
Expand Down Expand Up @@ -120,7 +117,7 @@ def state(self):
# Randomize car speeds and directions, also reset their position if initialize=True
def _randomize_cars(self, initialize=False):
speeds = self.random.randint(1,6,8)
directions = self.random.choice([-1,1],8)
directions = np.sign(self.random.rand(8) - 0.5).astype(int)
speeds*=directions
if(initialize):
self.cars = []
Expand Down
45 changes: 21 additions & 24 deletions minatar/environments/seaquest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,25 @@


#####################################################################################################################
# Env
# Env
#
# The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The
# player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished
# by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by
# one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can
# move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The
# player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time,
# and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued
# diver on board. The player can carry a maximum of 6 divers. When surfacing with less than 6, one diver is removed.
# When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each
# time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies.
# Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reaches 0; or when the
# player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel
# The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The
# player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished
# by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by
# one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can
# move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The
# player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time,
# and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued
# diver on board. The player can carry a maximum of 6 divers. When surfacing with less than 6, one diver is removed.
# When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each
# time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies.
# Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reaches 0; or when the
# player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel
# active in their previous location to reduce partial observability.
#
#####################################################################################################################
class Env:
def __init__(self, ramping = True, random_state = None):
def __init__(self, ramping=True):
self.channels ={
'sub_front':0,
'sub_back':1,
Expand All @@ -55,10 +55,7 @@ def __init__(self, ramping = True, random_state = None):
}
self.action_map = ['n','l','u','r','d','f']
self.ramping = ramping
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.random = np.random.RandomState()
self.reset()

# Update environment according to agent action
Expand Down Expand Up @@ -207,7 +204,7 @@ def act(self, a):
r+=self._surface()
return r, self.terminal

# Called when player hits surface (top row) if they have no divers, this ends the game,
# Called when player hits surface (top row) if they have no divers, this ends the game,
# if they have 6 divers this gives reward proportional to the remaining oxygen and restores full oxygen
# otherwise this reduces the number of divers and restores full oxygen
def _surface(self):
Expand All @@ -230,10 +227,10 @@ def _surface(self):
# Spawn an enemy fish or submarine in random row and random direction,
# if the resulting row and direction would lead to a collision, do nothing instead
def _spawn_enemy(self):
lr = self.random.choice([True,False])
is_sub = self.random.choice([True,False], p=[1/3,2/3])
lr = self.random.rand() < 1/2
is_sub = self.random.rand() < 1/3
x = 0 if lr else 9
y = self.random.choice(np.arange(1,9))
y = self.random.randint(low=1, high=9)

# Do not spawn in the same row and opposite direction as an existing entity
if(any([z[1]==y and z[2]!=lr for z in self.e_subs+self.e_fish])):
Expand All @@ -245,9 +242,9 @@ def _spawn_enemy(self):

# Spawn a diver in random row with random direction
def _spawn_diver(self):
    """Append a diver [x, y, moving_right, move_timer] at a random edge/row.

    Enters from the left (x=0) or right (x=9) with equal probability, in a
    random row 1..8. Removes the duplicated old/new diff lines for lr/y.
    """
    lr = self.random.rand() < 1 / 2
    x = 0 if lr else 9
    y = self.random.randint(low=1, high=9)
    # diver_move_interval is a module-level constant (timer until the
    # diver's next move) — defined elsewhere in this file.
    self.divers += [[x, y, lr, diver_move_interval]]

# Query the current level of the difficulty ramp, could be used as additional input to agent for example
Expand Down
Loading

1 comment on commit 42741e5

@kenjyoung
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, this update removes the random_seed argument in environment initialization. Any existing code using this argument should be updated to call env.seed(random_seed) after initialization instead.

Please sign in to comment.