Allow setting of sticky action probability, setting fixed random seed, and toggling difficulty ramp via environment initializer
kenjyoung committed Mar 7, 2019
1 parent b77ac3a commit cfa4227
Showing 7 changed files with 35 additions and 51 deletions.
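
In effect, this commit adds three knobs to the top-level wrapper. A minimal usage sketch, based on the new Environment.__init__ signature in minatar_environment/environment.py below (the minatar import path is an assumption, not part of this diff):

from minatar import Environment  # import path assumed

# The three options introduced by this commit:
env = Environment('breakout',
                  sticky_action_prob=0.1,   # chance each step of repeating the previous action
                  difficulty_ramping=True,  # accepted but unused by games that do not ramp
                  random_seed=0)            # seeds the per-game np.random.RandomState

r, terminal = env.act(3)  # sticky actions are now applied here, before dispatch to the game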
16 changes: 6 additions & 10 deletions environments/asterix.py
@@ -26,7 +26,7 @@
 #
 #####################################################################################################################
 class Env:
-    def __init__(self, ramping = True):
+    def __init__(self, ramping = True, seed = None):
         self.channels ={
             'player':0,
             'enemy':1,
@@ -35,6 +35,7 @@ def __init__(self, ramping = True):
         }
         self.action_map = ['n','l','u','r','d','f']
         self.ramping = ramping
+        self.random = np.random.RandomState(seed)
         self.reset()

     # Update environment according to agent action
@@ -43,10 +44,7 @@ def act(self, a):
         if(self.terminal):
             return r, self.terminal

-        if(np.random.rand()>0.1):
-            a = self.action_map[a]
-        else:
-            a = self.last_action
+        a = self.action_map[a]

         # Spawn enemy if timer is up
         if(self.spawn_timer==0):
@@ -104,18 +102,17 @@ def act(self, a):
                     self.spawn_speed-=1
                 self.ramp_index+=1
                 self.ramp_timer=ramp_interval
-        self.last_action = a
         return r, self.terminal

     # Spawn a new enemy or treasure at a random location with random direction (if all rows are filled do nothing)
     def _spawn_entity(self):
-        lr = np.random.choice([True,False])
-        is_gold = np.random.choice([True,False], p=[1/3,2/3])
+        lr = self.random.choice([True,False])
+        is_gold = self.random.choice([True,False], p=[1/3,2/3])
         x = 0 if lr else 9
         slot_options = [i for i in range(len(self.entities)) if self.entities[i]==None]
         if(not slot_options):
             return
-        slot = np.random.choice(slot_options)
+        slot = self.random.choice(slot_options)
         self.entities[slot] = [x,slot+1,lr,is_gold]

     # Query the current level of the difficulty ramp, could be used as additional input to agent for example
@@ -147,7 +144,6 @@ def reset(self):
         self.move_timer = self.move_speed
         self.ramp_timer = ramp_interval
         self.ramp_index = 0
-        self.last_action = 0
         self.terminal = False

     # Dimensionality of the game-state (10x10xn)
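
The same change repeats in every game file below: a per-environment np.random.RandomState replaces direct np.random calls, so a fixed seed reproduces spawn locations and directions regardless of the global NumPy state. A short illustration of the property being relied on (illustrative only, not part of the diff):

import numpy as np

rng_a = np.random.RandomState(42)
rng_b = np.random.RandomState(42)

# Equal seeds give identical, self-contained draw sequences.
assert [int(rng_a.choice(10)) for _ in range(5)] == [int(rng_b.choice(10)) for _ in range(5)]

# seed=None (the new default) still works, seeding from OS entropy.
rng_c = np.random.RandomState(None)
print(rng_c.randint(1, 6, 8))  # same shape of draw as freeway's eight car speeds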
12 changes: 4 additions & 8 deletions environments/breakout.py
@@ -17,14 +17,15 @@
 #
 #####################################################################################################################
 class Env:
-    def __init__(self):
+    def __init__(self, ramping = None, seed = None):
         self.channels ={
             'paddle':0,
             'ball':1,
             'trail':2,
             'brick':3,
         }
         self.action_map = ['n','l','u','r','d','f']
+        self.random = np.random.RandomState(seed)
         self.reset()

     # Update environment according to agent action
@@ -33,10 +34,7 @@ def act(self, a):
         if(self.terminal):
             return r, self.terminal

-        if(np.random.rand()>0.1):
-            a = self.action_map[a]
-        else:
-            a = self.last_action
+        a = self.action_map[a]

         # Resolve player action
         if(a=='l'):
@@ -95,7 +93,6 @@

         self.ball_x = new_x
         self.ball_y = new_y
-        self.last_action = a
         return r, self.terminal

     # Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None
@@ -114,12 +111,11 @@ def state(self):
     # Reset to start state for new episode
     def reset(self):
         self.ball_y = 3
-        ball_start = np.random.choice(2)
+        ball_start = self.random.choice(2)
         self.ball_x, self.ball_dir = [(0,2),(9,3)][ball_start]
         self.pos = 4
         self.brick_map = np.zeros((10,10))
         self.brick_map[1:4,:] = 1
-        self.last_action = 0
         self.strike = False
         self.last_x = self.ball_x
         self.last_y = self.ball_y
14 changes: 5 additions & 9 deletions environments/freeway.py
@@ -28,7 +28,7 @@
 #
 #####################################################################################################################
 class Env:
-    def __init__(self):
+    def __init__(self, ramping = None, seed = None):
         self.channels ={
             'chicken':0,
             'car':1,
@@ -39,6 +39,7 @@ def __init__(self):
             'speed5':6,
         }
         self.action_map = ['n','l','u','r','d','f']
+        self.random = np.random.RandomState(seed)
         self.reset()

     # Update environment according to agent action
@@ -47,10 +48,7 @@ def act(self, a):
         if(self.terminal):
             return r, self.terminal

-        if(np.random.rand()>0.1):
-            a = self.action_map[a]
-        else:
-            a = self.last_action
+        a = self.action_map[a]

         if(a=='u' and self.move_timer==0):
             self.move_timer = player_speed
@@ -86,7 +84,6 @@ def act(self, a):
         self.terminate_timer-=1
         if(self.terminate_timer<0):
             self.terminal = True
-        self.last_action = a
         return r, self.terminal

     # Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None
@@ -119,8 +116,8 @@ def state(self):

     # Randomize car speeds and directions, also reset their position if initialize=True
     def _randomize_cars(self, initialize=False):
-        speeds = np.random.randint(1,6,8)
-        directions = np.random.choice([-1,1],8)
+        speeds = self.random.randint(1,6,8)
+        directions = self.random.choice([-1,1],8)
         speeds*=directions
         if(initialize):
             self.cars = []
@@ -136,7 +133,6 @@ def reset(self):
         self.pos = 9
         self.move_timer = player_speed
         self.terminate_timer = time_limit
-        self.last_action = 0
         self.terminal = False

     # Dimensionality of the game-state (10x10xn)
21 changes: 8 additions & 13 deletions environments/seaquest.py
@@ -40,7 +40,7 @@
 #
 #####################################################################################################################
 class Env:
-    def __init__(self, ramping = True):
+    def __init__(self, ramping = True, seed = None):
         self.channels ={
             'sub_front':0,
             'sub_back':1,
@@ -55,6 +55,7 @@ def __init__(self, ramping = True):
         }
         self.action_map = ['n','l','u','r','d','f']
         self.ramping = ramping
+        self.random = np.random.RandomState(seed)
         self.reset()

     # Update environment according to agent action
@@ -63,10 +64,7 @@ def act(self, a):
         if(self.terminal):
             return r, self.terminal

-        if(np.random.rand()>0.1):
-            a = self.action_map[a]
-        else:
-            a = self.last_action
+        a = self.action_map[a]

         # Spawn enemy if timer is up
         if(self.e_spawn_timer==0):
@@ -203,8 +201,6 @@ def act(self, a):
                 self.terminal = True
             else:
                 r+=self._surface()
-
-        self.last_action = a
         return r, self.terminal

     # Called when player hits surface (top row) if they have no divers, this ends the game,
@@ -230,10 +226,10 @@ def _surface(self):
     # Spawn an enemy fish or submarine in random row and random direction,
     # if the resulting row and direction would lead to a collision, do nothing instead
     def _spawn_enemy(self):
-        lr = np.random.choice([True,False])
-        is_sub = np.random.choice([True,False], p=[1/3,2/3])
+        lr = self.random.choice([True,False])
+        is_sub = self.random.choice([True,False], p=[1/3,2/3])
         x = 0 if lr else 9
-        y = np.random.choice(np.arange(1,9))
+        y = self.random.choice(np.arange(1,9))

         # Do not spawn in same row an opposite direction as existing
         if(any([z[1]==y and z[2]!=lr for z in self.e_subs+self.e_fish])):
@@ -245,9 +241,9 @@

     # Spawn a diver in random row with random direction
     def _spawn_diver(self):
-        lr = np.random.choice([True,False])
+        lr = self.random.choice([True,False])
         x = 0 if lr else 9
-        y = np.random.choice(np.arange(1,9))
+        y = self.random.choice(np.arange(1,9))
         self.divers+=[[x,y,lr,diver_move_interval]]

     # Query the current level of the difficulty ramp, could be used as additional input to agent for example
@@ -304,7 +300,6 @@ def reset(self):
         self.ramp_index = 0
         self.shot_timer = 0
         self.surface = True
-        self.last_action = 0
         self.terminal = False

     # Dimensionality of the game-state (10x10xn)
11 changes: 3 additions & 8 deletions environments/space_invaders.py
@@ -29,7 +29,7 @@
 #
 #####################################################################################################################
 class Env:
-    def __init__(self, ramping = True):
+    def __init__(self, ramping = True, seed = None):
         self.channels ={
             'cannon':0,
             'alien':1,
@@ -40,6 +40,7 @@ def __init__(self, ramping = True):
         }
         self.action_map = ['n','l','u','r','d','f']
         self.ramping = ramping
+        self.random = np.random.RandomState(seed)
         self.reset()

     # Update environment according to agent action
@@ -48,10 +49,7 @@ def act(self, a):
         if(self.terminal):
             return r, self.terminal

-        if(np.random.rand()>0.1):
-            a = self.action_map[a]
-        else:
-            a = self.last_action
+        a = self.action_map[a]

         # Resolve player action
         if(a=='f' and self.shot_timer == 0):
@@ -106,8 +104,6 @@ def act(self, a):
                 self.enemy_move_interval-=1
                 self.ramp_index+=1
             self.alien_map[0:4,2:8] = 1
-
-        self.last_action = a
         return r, self.terminal

     # Find the alien closest to player in manhattan distance, currently used to decide which alien shoots
@@ -149,7 +145,6 @@ def reset(self):
         self.alien_shot_timer = enemy_shot_interval
         self.ramp_index = 0
         self.shot_timer = 0
-        self.last_action = 0
         self.terminal = False

     # Dimensionality of the game-state (10x10xn)
10 changes: 8 additions & 2 deletions minatar_environment/environment.py
@@ -4,6 +4,7 @@
 # Tian Tian ([email protected]) #
 ################################################################################################################
 from importlib import import_module
+import numpy as np


 #####################################################################################################################
@@ -14,14 +15,19 @@
 #
 #####################################################################################################################
 class Environment:
-    def __init__(self, env_name):
+    def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True, random_seed = None):
         env_module = import_module('minatar.environments.'+env_name)
         self.env_name = env_name
-        self.env = env_module.Env()
+        self.env = env_module.Env(ramping = difficulty_ramping, seed = random_seed)
         self.n_channels = self.env.state_shape()[2]
+        self.sticky_action_prob = sticky_action_prob
+        self.last_action = 0

     # Wrapper for env.act
     def act(self, a):
+        if(np.random.rand()<self.sticky_action_prob):
+            a = self.last_action
+        self.last_action = a
         return self.env.act(a)

     # Wrapper for env.state
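
Moving the sticky-action coin flip here makes each game's act() deterministic given its RandomState, while action repetition is applied once, centrally. Note that the wrapper draws the coin from the global np.random rather than from the seeded per-game stream, so an exactly reproducible run needs sticky_action_prob=0 or a seeded global generator. A hedged sketch (import path assumed):

import numpy as np
from minatar import Environment  # import path assumed

np.random.seed(0)  # pins the wrapper's sticky-action coin flips
env = Environment('seaquest', sticky_action_prob=0.1, random_seed=123)

for t in range(1000):
    r, terminal = env.act(0)  # action index 0 is 'n' (no-op) in every game's action_map
    if terminal:
        break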
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
 from distutils.core import setup

 setup(name='MinAtar',
-      version='1.0',
+      version='1.0.1',
       description='A miniaturized version of the arcade learning environment.',
       url='https://github.com/kenjyoung/MinAtar',
       author='Kenny Young',
