From 42741e5217119cdda9155b3cea16f15bd9b90113 Mon Sep 17 00:00:00 2001 From: Simone Parisi Date: Mon, 8 May 2023 10:33:14 -0600 Subject: [PATCH] Fast random (#31) * improved readability * rgb render mode for pixels obs * Update gym.py * Update gym.py * Update gym.py gym\utils\passive_env_checker already checks that * Update gym.py when in human mode, env is always rendered as in gym * Update README.md * replaced choice with better rand functions * better seeding seed and reset do not remake the env from scratch, they just set the random_state the whole code is more gym-like * Update setup.py * Update environment.py * Update gym.py * try import guard * Update README.md * Update gym.py * Minor changes to README * readme * Update README.md * fixed my email domain in setup.py --------- Co-authored-by: Kenny Young --- minatar/environment.py | 35 +++++++++++--------- minatar/environments/asterix.py | 23 ++++++------- minatar/environments/breakout.py | 19 +++++------ minatar/environments/freeway.py | 27 +++++++--------- minatar/environments/seaquest.py | 45 ++++++++++++-------------- minatar/environments/space_invaders.py | 27 +++++++--------- minatar/gym.py | 37 +++++++++------------ setup.py | 2 +- 8 files changed, 100 insertions(+), 115 deletions(-) diff --git a/minatar/environment.py b/minatar/environment.py index 2181897..31ceef8 100644 --- a/minatar/environment.py +++ b/minatar/environment.py @@ -6,6 +6,14 @@ from importlib import import_module import numpy as np +try: + from matplotlib import pyplot as plt + from matplotlib import colors + import seaborn as sns +except ImportError: + import logging + logging.warning("Cannot import matplotlib and/or seaborn. " + "Will not be able to render the environment.") ##################################################################################################################### # Environment @@ -16,21 +24,26 @@ ##################################################################################################################### class Environment: def __init__(self, env_name, sticky_action_prob=0.1, - difficulty_ramping=True, random_seed=None): + difficulty_ramping=True): env_module = import_module('minatar.environments.' 
+ env_name) - self.random = np.random.RandomState(random_seed) + self.random = np.random.RandomState() self.env_name = env_name - self.env = env_module.Env( - ramping=difficulty_ramping, random_state=self.random) + self.env = env_module.Env(ramping=difficulty_ramping) self.n_channels = self.env.state_shape()[2] self.sticky_action_prob = sticky_action_prob self.last_action = 0 self.visualized = False self.closed = False + # Seeding numpy random for reproducibility + def seed(self, seed=None): + if seed is not None: + self.random = np.random.RandomState(seed) + self.env.random = self.random + # Wrapper for env.act def act(self, a): - if(self.random.rand() < self.sticky_action_prob): + if self.random.rand() < self.sticky_action_prob: a = self.last_action self.last_action = a return self.env.act(a) @@ -61,15 +74,7 @@ def minimal_action_set(self): # Display the current environment state for time milliseconds using matplotlib def display_state(self, time=50): - if(not self.visualized): - global plt - global colors - global sns - mpl = __import__('matplotlib.pyplot', globals(), locals()) - plt = mpl.pyplot - mpl = __import__('matplotlib.colors', globals(), locals()) - colors = mpl.colors - sns = __import__('seaborn', globals(), locals()) + if not self.visualized: self.cmap = sns.color_palette("cubehelix", self.n_channels) self.cmap.insert(0, (0,0,0)) self.cmap = colors.ListedColormap(self.cmap) @@ -78,7 +83,7 @@ def display_state(self, time=50): _, self.ax = plt.subplots(1,1) plt.show(block=False) self.visualized = True - if(self.closed): + if self.closed: _, self.ax = plt.subplots(1,1) plt.show(block=False) self.closed = False diff --git a/minatar/environments/asterix.py b/minatar/environments/asterix.py index 55acf93..0513c82 100644 --- a/minatar/environments/asterix.py +++ b/minatar/environments/asterix.py @@ -17,7 +17,7 @@ ##################################################################################################################### -# Env +# Env # # The player can move freely along the 4 cardinal directions. Enemies and treasure spawn from the sides. A reward of # +1 is given for picking up treasure. Termination occurs if the player makes contact with an enemy. 
Enemy and @@ -26,7 +26,7 @@ # ##################################################################################################################### class Env: - def __init__(self, ramping = True, random_state = None): + def __init__(self, ramping=True): self.channels ={ 'player':0, 'enemy':1, @@ -35,10 +35,7 @@ def __init__(self, ramping = True, random_state = None): } self.action_map = ['n','l','u','r','d','f'] self.ramping = ramping - if random_state is None: - self.random = np.random.RandomState() - else: - self.random = random_state + self.random = np.random.RandomState() self.reset() # Update environment according to agent action @@ -46,7 +43,7 @@ def act(self, a): r = 0 if(self.terminal): return r, self.terminal - + a = self.action_map[a] # Spawn enemy if timer is up @@ -65,9 +62,9 @@ def act(self, a): self.player_y = min(8, self.player_y+1) # Update entities - for i in range(len(self.entities)): - x = self.entities[i] - if(x is not None): + for i in range(len(self.entities)): + x = self.entities[i] + if(x is not None): if(x[0:2]==[self.player_x,self.player_y]): if(self.entities[i][3]): self.entities[i] = None @@ -109,13 +106,13 @@ def act(self, a): # Spawn a new enemy or treasure at a random location with random direction (if all rows are filled do nothing) def _spawn_entity(self): - lr = self.random.choice([True,False]) - is_gold = self.random.choice([True,False], p=[1/3,2/3]) + lr = self.random.rand() < 1/2 + is_gold = self.random.rand() < 1/3 x = 0 if lr else 9 slot_options = [i for i in range(len(self.entities)) if self.entities[i]==None] if(not slot_options): return - slot = self.random.choice(slot_options) + slot = slot_options[self.random.randint(len(slot_options))] self.entities[slot] = [x,slot+1,lr,is_gold] # Query the current level of the difficulty ramp, could be used as additional input to agent for example diff --git a/minatar/environments/breakout.py b/minatar/environments/breakout.py index 4df5dcd..3eabc0e 100644 --- a/minatar/environments/breakout.py +++ b/minatar/environments/breakout.py @@ -9,15 +9,15 @@ ##################################################################################################################### # Env # -# The player controls a paddle on the bottom of the screen and must bounce a ball tobreak 3 rows of bricks along the -# top of the screen. A reward of +1 is given for each brick broken by the ball. When all bricks are cleared another 3 -# rows are added. The ball travels only along diagonals, when it hits the paddle it is bounced either to the left or +# The player controls a paddle on the bottom of the screen and must bounce a ball to break 3 rows of bricks along the +# top of the screen. A reward of +1 is given for each brick broken by the ball. When all bricks are cleared another 3 +# rows are added. The ball travels only along diagonals, when it hits the paddle it is bounced either to the left or # right depending on the side of the paddle hit, when it hits a wall or brick it is reflected. Termination occurs when # the ball hits the bottom of the screen. The balls direction is indicated by a trail channel. 
# ##################################################################################################################### class Env: - def __init__(self, ramping = None, random_state = None): + def __init__(self, ramping=None): self.channels ={ 'paddle':0, 'ball':1, @@ -25,10 +25,7 @@ def __init__(self, ramping = None, random_state = None): 'brick':3, } self.action_map = ['n','l','u','r','d','f'] - if random_state is None: - self.random = np.random.RandomState() - else: - self.random = random_state + self.random = np.random.RandomState() self.reset() # Update environment according to agent action @@ -36,7 +33,7 @@ def act(self, a): r = 0 if(self.terminal): return r, self.terminal - + a = self.action_map[a] # Resolve player action @@ -100,7 +97,7 @@ def act(self, a): # Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None def difficulty_ramp(self): - return None + return None # Process the game-state into the 10x10xn state provided to the agent and return def state(self): @@ -114,7 +111,7 @@ def state(self): # Reset to start state for new episode def reset(self): self.ball_y = 3 - ball_start = self.random.choice(2) + ball_start = self.random.randint(2) self.ball_x, self.ball_dir = [(0,2),(9,3)][ball_start] self.pos = 4 self.brick_map = np.zeros((10,10)) diff --git a/minatar/environments/freeway.py b/minatar/environments/freeway.py index c1597e9..b714c4e 100644 --- a/minatar/environments/freeway.py +++ b/minatar/environments/freeway.py @@ -17,18 +17,18 @@ ##################################################################################################################### # Env # -# The player begins at the bottom of the screen and motion is restricted to traveling up and down. Player speed is -# also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches -# the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen -# and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of -# the screen. Car direction and speed is indicated by 5 trail channels, the location of the trail gives direction -# while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames). -# Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs +# The player begins at the bottom of the screen and motion is restricted to traveling up and down. Player speed is +# also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches +# the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen +# and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of +# the screen. Car direction and speed are indicated by 5 trail channels, the location of the trail gives direction +# while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames). +# Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs # after 2500 frames have elapsed. 
# ##################################################################################################################### class Env: - def __init__(self, ramping = None, random_state = None): + def __init__(self, ramping=None): self.channels ={ 'chicken':0, 'car':1, @@ -39,10 +39,7 @@ def __init__(self, ramping = None, random_state = None): 'speed5':6, } self.action_map = ['n','l','u','r','d','f'] - if random_state is None: - self.random = np.random.RandomState() - else: - self.random = random_state + self.random = np.random.RandomState() self.reset() # Update environment according to agent action @@ -50,7 +47,7 @@ def act(self, a): r = 0 if(self.terminal): return r, self.terminal - + a = self.action_map[a] if(a=='u' and self.move_timer==0): @@ -91,7 +88,7 @@ def act(self, a): # Query the current level of the difficulty ramp, difficulty does not ramp in this game, so return None def difficulty_ramp(self): - return None + return None # Process the game-state into the 10x10xn state provided to the agent and return def state(self): @@ -120,7 +117,7 @@ def state(self): # Randomize car speeds and directions, also reset their position if initialize=True def _randomize_cars(self, initialize=False): speeds = self.random.randint(1,6,8) - directions = self.random.choice([-1,1],8) + directions = np.where(self.random.rand(8) < 0.5, -1, 1) speeds*=directions if(initialize): self.cars = [] diff --git a/minatar/environments/seaquest.py b/minatar/environments/seaquest.py index f718919..be74c3b 100644 --- a/minatar/environments/seaquest.py +++ b/minatar/environments/seaquest.py @@ -22,25 +22,25 @@ ##################################################################################################################### -# Env +# Env # -# The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The -# player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished -# by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by -# one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can -# move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The -# player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time, -# and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued -# diver on board. The player can carry a maximum of 6 divers. When surfacing with less than 6, one diver is removed. -# When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each -# time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies. -# Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reached 0; or when the -# player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel +# The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The +# player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished +# by the fact that submarines shoot bullets and fish do not. 
A reward of +1 is given each time an enemy is struck by +# one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can +# move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The +# player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time, +# and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued +# diver on board. The player can carry a maximum of 6 divers. When surfacing with fewer than 6, one diver is removed. +# When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each +# time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies. +# Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reaches 0; or when the +# player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel # active in their previous location to reduce partial observability. # ##################################################################################################################### class Env: - def __init__(self, ramping = True, random_state = None): + def __init__(self, ramping=True): self.channels ={ 'sub_front':0, 'sub_back':1, @@ -55,10 +55,7 @@ def __init__(self, ramping = True, random_state = None): } self.action_map = ['n','l','u','r','d','f'] self.ramping = ramping - if random_state is None: - self.random = np.random.RandomState() - else: - self.random = random_state + self.random = np.random.RandomState() self.reset() # Update environment according to agent action @@ -207,7 +204,7 @@ def act(self, a): r+=self._surface() return r, self.terminal - # Called when player hits surface (top row) if they have no divers, this ends the game, # if they have 6 divers this gives reward proportional to the remaining oxygen and restores full oxygen # otherwise this reduces the number of divers and restores full oxygen def _surface(self): @@ -230,10 +227,10 @@ def _surface(self): # Spawn an enemy fish or submarine in random row and random direction, # if the resulting row and direction would lead to a collision, do nothing instead def _spawn_enemy(self): - lr = self.random.choice([True,False]) - is_sub = self.random.choice([True,False], p=[1/3,2/3]) + lr = self.random.rand() < 1/2 + is_sub = self.random.rand() < 1/3 x = 0 if lr else 9 - y = self.random.choice(np.arange(1,9)) + y = self.random.randint(low=1, high=9) # Do not spawn in same row an opposite direction as existing if(any([z[1]==y and z[2]!=lr for z in self.e_subs+self.e_fish])): @@ -245,9 +242,9 @@ def _spawn_enemy(self): # Spawn a diver in random row with random direction def _spawn_diver(self): - lr = self.random.choice([True,False]) + lr = self.random.rand() < 1/2 x = 0 if lr else 9 - y = self.random.choice(np.arange(1,9)) + y = self.random.randint(low=1, high=9) self.divers+=[[x,y,lr,diver_move_interval]] # Query the current level of the difficulty ramp, could be used as additional input to agent for example diff --git a/minatar/environments/space_invaders.py b/minatar/environments/space_invaders.py index b0c35dc..6d23670 100644 --- a/minatar/environments/space_invaders.py +++ b/minatar/environments/space_invaders.py @@ -16,20 +16,20 @@ 
##################################################################################################################### -# Env +# Env # -# The player controls a cannon at the bottom of the screen and can shoot bullets upward at a cluster of aliens above. -# The aliens move across the screen until one of them hits the edge, at which point they all move down and switch -# directions. The current alien direction is indicated by 2 channels (one for left and one for right) one of which is -# active at the location of each alien. A reward of +1 is given each time an alien is shot, and that alien is also -# removed. The aliens will also shoot bullets back at the player. When few aliens are left, alien speed will begin to -# increase. When only one alien is left, it will move at one cell per frame. When a wave of aliens is fully cleared a -# new one will spawn which moves at a slightly faster speed than the last. Termination occurs when an alien or bullet +# The player controls a cannon at the bottom of the screen and can shoot bullets upward at a cluster of aliens above. +# The aliens move across the screen until one of them hits the edge, at which point they all move down and switch +# directions. The current alien direction is indicated by 2 channels (one for left and one for right) one of which is +# active at the location of each alien. A reward of +1 is given each time an alien is shot, and that alien is also +# removed. The aliens will also shoot bullets back at the player. When few aliens are left, alien speed will begin to +# increase. When only one alien is left, it will move at one cell per frame. When a wave of aliens is fully cleared a +# new one will spawn which moves at a slightly faster speed than the last. Termination occurs when an alien or bullet # hits the player. # ##################################################################################################################### class Env: - def __init__(self, ramping = True, random_state=None): + def __init__(self, ramping=True): self.channels ={ 'cannon':0, 'alien':1, @@ -40,10 +40,7 @@ def __init__(self, ramping = True, random_state=None): } self.action_map = ['n','l','u','r','d','f'] self.ramping = ramping - if random_state is None: - self.random = np.random.RandomState() - else: - self.random = random_state + self.random = np.random.RandomState() self.reset() # Update environment according to agent action @@ -97,7 +94,7 @@ def act(self, a): r+=np.sum(kill_locations) self.alien_map[kill_locations] = self.f_bullet_map[kill_locations] = 0 - + # Update various timers self.shot_timer -= self.shot_timer>0 self.alien_move_timer-=1 @@ -121,7 +118,7 @@ def _nearest_alien(self, pos): # Query the current level of the difficulty ramp, could be used as additional input to agent for example def difficulty_ramp(self): return self.ramp_index - + # Process the game-state into the 10x10xn state provided to the agent and return def state(self): state = np.zeros((10,10,len(self.channels)),dtype=bool) diff --git a/minatar/gym.py b/minatar/gym.py index 2a525ce..d2289c9 100644 --- a/minatar/gym.py +++ b/minatar/gym.py @@ -4,20 +4,25 @@ from gym import spaces from gym.envs import register +try: + import seaborn as sns +except ImportError: + import logging + logging.warning("Cannot import seaborn. " 
+ "Will not be able to train from pixel observations.") + from minatar import Environment class BaseEnv(gym.Env): metadata = {"render_modes": ["human", "array", "rgb_array"]} - def __init__(self, game, display_time=50, use_minimal_action_set=False, - render_mode=None, **kwargs): - self.game_name = game - self.display_time = display_time + def __init__(self, game, render_mode=None, display_time=50, + use_minimal_action_set=False, **kwargs): self.render_mode = render_mode + self.display_time = display_time - self.game_kwargs = kwargs - self.seed() + self.game = Environment(env_name=game, **kwargs) if use_minimal_action_set: self.action_set = self.game.minimal_action_set() @@ -36,26 +41,17 @@ def step(self, action): self.render() return self.game.state(), reward, done, False, {} + def seed(self, seed=None): + self.game.seed(seed) + def reset(self, seed=None, options=None): - if(seed is not None): - self.game = Environment( - env_name=self.game_name, - random_seed=seed, - **self.game_kwargs - ) + if seed is not None: + self.seed(seed) self.game.reset() if self.render_mode == "human": self.render() return self.game.state(), {} - def seed(self, seed=None): - self.game = Environment( - env_name=self.game_name, - random_seed=seed, - **self.game_kwargs - ) - return seed - def render(self): if self.render_mode is None: gym.logger.warn( @@ -71,7 +67,6 @@ def render(self): elif self.render_mode == "rgb_array": # use the same color palette of Environment.display_state state = self.game.state() n_channels = state.shape[-1] - sns = __import__('seaborn', globals(), locals()) cmap = sns.color_palette("cubehelix", n_channels) cmap.insert(0, (0,0,0)) numerical_state = np.amax( diff --git a/setup.py b/setup.py index ec2b5d3..ccca1ec 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup( name='MinAtar', - version='1.0.12', + version='1.0.13', description='A miniaturized version of the Arcade Learning Environment.', url='https://github.com/kenjyoung/MinAtar', author='Kenny Young',
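The core of the change is replacing np.random.choice, which carries significant overhead for small draws, with primitive draws on the same RandomState. A minimal sketch of the substitutions used throughout the environments, collected here for reference (rng and slot_options are illustrative stand-ins for each environment's self.random and its list of free slots):

    import numpy as np

    rng = np.random.RandomState(42)
    slot_options = [0, 2, 5]  # hypothetical list of free entity slots

    # Fair coin flip: replaces rng.choice([True, False])
    lr = rng.rand() < 1/2
    # Weighted coin flip (True with prob 1/3): replaces rng.choice([True, False], p=[1/3, 2/3])
    is_gold = rng.rand() < 1/3
    # Uniform pick from a list: replaces rng.choice(slot_options)
    slot = slot_options[rng.randint(len(slot_options))]
    # Uniform integer in [1, 9): replaces rng.choice(np.arange(1, 9))
    y = rng.randint(low=1, high=9)

Each substitution is equivalent in distribution to the choice call it replaces; only the underlying draw is cheaper.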
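Seeding likewise now follows the gym convention: Environment.seed() installs a freshly seeded RandomState shared with the inner game, and reset() no longer rebuilds the environment. A minimal usage sketch of the flow under the patched API (the policy and training loop are elided):

    from minatar import Environment

    env = Environment('breakout', sticky_action_prob=0.1)
    env.seed(42)    # sets env.random and the inner env's random_state in one step
    env.reset()
    actions = env.minimal_action_set()
    reward, terminal = env.act(actions[0])  # act() returns (reward, terminal)
    obs = env.state()                       # 10x10xn boolean observation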