Skip to content

Commit

Permalink
A random seed used to initialize an environment will now also be used…
Browse files Browse the repository at this point in the history
… to decide sticky actions. The previous behaviour was unintuitive: the passed random seed was used for the other random events, while sticky actions used NumPy's default (global) randomization behaviour and therefore ignored the seed.
  • Loading branch information
kenjyoung committed Jun 11, 2021
1 parent 310c7c3 commit 8b39a18
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 13 deletions.
5 changes: 3 additions & 2 deletions minatar/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
class Environment:
def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True, random_seed = None):
env_module = import_module('minatar.environments.'+env_name)
self.random = np.random.RandomState(random_seed)
self.env_name = env_name
self.env = env_module.Env(ramping = difficulty_ramping, seed = random_seed)
self.env = env_module.Env(ramping = difficulty_ramping, random_state = self.random)
self.n_channels = self.env.state_shape()[2]
self.sticky_action_prob = sticky_action_prob
self.last_action = 0
Expand All @@ -27,7 +28,7 @@ def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True

# Wrapper for env.act
def act(self, a):
if(np.random.rand()<self.sticky_action_prob):
if(self.random.rand()<self.sticky_action_prob):
a = self.last_action
self.last_action = a
return self.env.act(a)
Expand Down
7 changes: 5 additions & 2 deletions minatar/environments/asterix.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = True, seed = None):
def __init__(self, ramping = True, random_state = None):
self.channels ={
'player':0,
'enemy':1,
Expand All @@ -35,7 +35,10 @@ def __init__(self, ramping = True, seed = None):
}
self.action_map = ['n','l','u','r','d','f']
self.ramping = ramping
self.random = np.random.RandomState(seed)
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.reset()

# Update environment according to agent action
Expand Down
7 changes: 5 additions & 2 deletions minatar/environments/breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = None, seed = None):
def __init__(self, ramping = None, random_state = None):
self.channels ={
'paddle':0,
'ball':1,
'trail':2,
'brick':3,
}
self.action_map = ['n','l','u','r','d','f']
self.random = np.random.RandomState(seed)
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.reset()

# Update environment according to agent action
Expand Down
7 changes: 5 additions & 2 deletions minatar/environments/freeway.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = None, seed = None):
def __init__(self, ramping = None, random_state = None):
self.channels ={
'chicken':0,
'car':1,
Expand All @@ -39,7 +39,10 @@ def __init__(self, ramping = None, seed = None):
'speed5':6,
}
self.action_map = ['n','l','u','r','d','f']
self.random = np.random.RandomState(seed)
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.reset()

# Update environment according to agent action
Expand Down
7 changes: 5 additions & 2 deletions minatar/environments/seaquest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = True, seed = None):
def __init__(self, ramping = True, random_state = None):
self.channels ={
'sub_front':0,
'sub_back':1,
Expand All @@ -55,7 +55,10 @@ def __init__(self, ramping = True, seed = None):
}
self.action_map = ['n','l','u','r','d','f']
self.ramping = ramping
self.random = np.random.RandomState(seed)
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.reset()

# Update environment according to agent action
Expand Down
7 changes: 5 additions & 2 deletions minatar/environments/space_invaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#
#####################################################################################################################
class Env:
def __init__(self, ramping = True, seed = None):
def __init__(self, ramping = True, random_state=None):
self.channels ={
'cannon':0,
'alien':1,
Expand All @@ -40,7 +40,10 @@ def __init__(self, ramping = True, seed = None):
}
self.action_map = ['n','l','u','r','d','f']
self.ramping = ramping
self.random = np.random.RandomState(seed)
if random_state is None:
self.random = np.random.RandomState()
else:
self.random = random_state
self.reset()

# Update environment according to agent action
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='MinAtar',
version='1.0.7',
version='1.0.8',
description='A miniaturized version of the arcade learning environment.',
url='https://github.com/kenjyoung/MinAtar',
author='Kenny Young',
Expand Down

0 comments on commit 8b39a18

Please sign in to comment.