Skip to content

Commit

Permalink
Updated AC_lambda.py and the results in README.md to be consistent wi…
Browse files Browse the repository at this point in the history
…th most recent paper
  • Loading branch information
kenjyoung committed Feb 14, 2020
1 parent ef52ed0 commit 87c2bc8
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 34 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# MinAtar
MinAtar is a testbed for AI agents which implements miniaturized version of several Atari 2600 games. MinAtar is inspired by the Arcade Learning Environment (Bellemare et. al. 2013), but simplifies the games to make experimentation with the environments more accessible and efficient. Currently, MinAtar provides analogues to five Atari games which play out on a 10x10 grid. The environments provide a 10x10xn state representation, where each of the n channels correspond to game-specific objects, such as ball, paddle and brick in the game Breakout.
MinAtar is a testbed for AI agents which implements miniaturized versions of several Atari 2600 games. MinAtar is inspired by the Arcade Learning Environment (Bellemare et. al. 2013) but simplifies the games to make experimentation with the environments more accessible and efficient. Currently, MinAtar provides analogues to five Atari games which play out on a 10x10 grid. The environments provide a 10x10xn state representation, where each of the n channels correspond to a game-specific object, such as ball, paddle and brick in the game Breakout.

<p align="center">
<img src="img/seaquest.gif" width="200" />
Expand Down Expand Up @@ -51,7 +51,7 @@ To play a game as a human, run examples/human_play.py as follows:
```bash
python human_play.py -g <game>
```
Use the arrow keys to move and space bar to fire. Also press q to quit and r to reset.
Use the arrow keys to move and space bar to fire. Also, press q to quit and r to reset.

Also included in the examples directory are example implementations of DQN (dqn.py) and online actor-critic with eligibility traces (AC_lambda.py).

Expand All @@ -70,7 +70,7 @@ This is the simplest way to visualize the environments, unless you need to handl


### Using GUI class
We also include a slightly more complex GUI for visualize the environments and optionally handle user input. This GUI is used in examples/human_play.py to play as a human and examples/agent_play.py to visualize the performance of trained agents. To use the GUI you can import it in your code with:
We also include a slightly more complex GUI to visualize the environments and optionally handle user input. This GUI is used in examples/human_play.py to play as a human and examples/agent_play.py to visualize the performance of trained agents. To use the GUI you can import it in your code with:
```python
from minatar import GUI
```
Expand All @@ -96,12 +96,13 @@ gui.run()
This will enter the agent environment interaction loop and then run the GUI thread, gui.run() will block until gui.quit() is called. To handle user input you can use gui.overwrite_key_handle(on_key_event, on_release_event). The arguments are functions to be called whenever a key is pressed, and released respectively. For an example of how to do this see examples/human_play.py.

## Results
The following plots display results for DQN (Mnih et al., 2015) and actor-critic with eligibility traces. Our DQN agent uses a significantly smaller network. We perform an ablation study of DQN, and display results for variants without experience replay, and without a separate target network. Our AC agent uses a similar architecture to DQN, but does not use experience replay. We display results for two values of the trace decay parameter, 0.8 and 0.0. Each curve is the average of 30 independent runs with different random seeds. For further information, see the paper on MinAtar available [here](https://arxiv.org/abs/1903.03176).
The following plots display results for DQN (Mnih et al., 2015) and actor-critic (AC) with eligibility traces. Our DQN agent uses a significantly smaller network compared to that of Mnih et al., 2015. We display results for DQN with and without experience reply. Our AC agent uses a similar architecture to DQN, but does not use experience replay. We display results for two values of the trace decay parameter, 0.8 and 0.0. Each curve is the average of 30 independent runs with different random seeds. The top plots display the sensitivity of final performance to the step-size parameter, while the bottom plots display the average return during training as a function of training frames. For further information, see the paper on MinAtar available [here](https://arxiv.org/abs/1903.03176).

<img align="center" src="img/results.gif" width=800>
<img align="center" src="img/sensitivity_curves.gif" width=800>
<img align="center" src="img/learning_curves.gif" width=800>

## Games
So far we have implemented analogues to five Atari games in MinAtar as follows. For each game we include a link to a video of a trained DQN agent playing.
So far we have implemented analogues to five Atari games in MinAtar as follows. For each game, we include a link to a video of a trained DQN agent playing.

### Asterix
The player can move freely along the 4 cardinal directions. Enemies and treasure spawn from the sides. A reward of +1 is given for picking up treasure. Termination occurs if the player makes contact with an enemy. Enemy and treasure direction are indicated by a trail channel. Difficulty is periodically increased by increasing the speed and spawn rate of enemies and treasure.
Expand All @@ -114,12 +115,12 @@ The player controls a paddle on the bottom of the screen and must bounce a ball
[Video](https://www.youtube.com/watch?v=cFk4efZNNVI&t)

### Freeway
The player begins at the bottom of the screen and the motion is restricted to traveling up and down. Player speed is also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of the screen. Car direction and speed is indicated by 5 trail channels. The location of the trail gives direction while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames). Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs after 2500 frames have elapsed.
The player begins at the bottom of the screen and the motion is restricted to travelling up and down. Player speed is also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of the screen. Car direction and speed is indicated by 5 trail channels. The location of the trail gives direction while the specific channel indicates how frequently the car moves (from once every frame to once every 5 frames). Each time the player successfully reaches the top of the screen, the car speeds are randomized. Termination occurs after 2500 frames have elapsed.

[Video](https://www.youtube.com/watch?v=gbj4jiTcryw)

### Seaquest
The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time, and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued diver on board. The player can carry a maximum of 6 divers. When surfacing with less than 6, one diver is removed. When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies. Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reaches 0; or when the player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel active in their previous location to reduce partial observability.
The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued diver on board. The player can carry a maximum of 6 divers. When surfacing with less than 6, one diver is removed. When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies. Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen reaches 0; or when the player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel active in their previous location to reduce partial observability.

[Video](https://www.youtube.com/watch?v=W9k38b5QPxA&t)

Expand Down Expand Up @@ -162,4 +163,4 @@ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
along with this program. If not, see <http://www.gnu.org/licenses/>.
60 changes: 36 additions & 24 deletions examples/AC_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
#
#####################################################################################################################
NUM_FRAMES = 5000000
ALPHA = 0.00390625
ALPHA = 0.00048828125
LAMBDA = 0.8
GAMMA = 0.99
BETA = 0.01
GAMMA_RMS = 0.999
EPS_RMS = 0.0001
MIN_DENOM = 0.0001

dSiLU = lambda x: torch.sigmoid(x)*(1+x*(1-torch.sigmoid(x)))
Expand All @@ -47,9 +49,9 @@
#####################################################################################################################
# ACNetwork
#
# Setup the AC-network with one hidden 2D conv with variable number of input channels. We use 16 filters, a quarter of
# the original DQN paper of 64. One hidden fully connected linear layer with a quarter of the original DQN paper of
# 512 rectified units. Finally, we use one output layer which is a fully connected softmax layer with a single output
# Setup the AC-network with one hidden 2D conv with variable number of input channels. We use 16 filters, a quarter of
# the original DQN paper of 64. One hidden fully connected linear layer with a quarter of the original DQN paper of
# 512 rectified units. Finally, we use one output layer which is a fully connected softmax layer with a single output
# for each valid action for the policy network, and another output which is a fully connected linear layer, with a
# single output for the state value.
#
Expand Down Expand Up @@ -110,8 +112,8 @@ def get_state(s):
#####################################################################################################################
# world_dynamics
#
# Responsible for world dynamics. It generates next state and reward after taking an action according to the behavior
# policy. The behavior policy is specified by the policy network output. Reward should be casted to float, otherwise
# Responsible for world dynamics. It generates next state and reward after taking an action according to the behavior
# policy. The behavior policy is specified by the policy network output. Reward should be casted to float, otherwise
# it is LongTensor, which is used for indexing.
#
# Inputs:
Expand All @@ -123,9 +125,9 @@ def get_state(s):
#
#####################################################################################################################
def world_dynamics(s, env, network):
# Since state is 10x10xchannel, we are not dealing with batch here. network(s)[0] specifies the policy network,
# which we use to draw an action according to a multinomial distribution over axis 1, (axis 0 iterates over samples,
# and is unused in this case. torch._no_grad() avoids tracking history in autograd.
# network(s)[0] specifies the policy network, which we use to draw an action according to a multinomial
# distribution over axis 1, (axis 0 iterates over samples, and is unused in this case. torch._no_grad()
# avoids tracking history in autograd.
with torch.no_grad():
action = torch.multinomial(network(s)[0],1)[0]

Expand All @@ -140,8 +142,7 @@ def world_dynamics(s, env, network):
#####################################################################################################################
# train
#
# This is where learning happens. More specifically, this function learns the weights of the policy/value network
# using huber loss.
# This is where learning happens. More specifically, this function updates the weights of the policy/value network.
#
# Inputs:
# sample: a single transition
Expand All @@ -151,7 +152,7 @@ def world_dynamics(s, env, network):
# alpha: learning rate for actor-critic update
#
#####################################################################################################################
def train(sample, traces, grads, network, alpha):
def train(sample, traces, grads, MSGs, network, alpha, time_step):
# states, next_states: (1, in_channel, 10, 10) - inline with pytorch NCHW format
# actions, rewards, is_terminal: (1, 1)
last_state = sample.last_state
Expand Down Expand Up @@ -180,8 +181,12 @@ def train(sample, traces, grads, network, alpha):
with torch.no_grad():
V_last = network(last_state)[1]
delta = GAMMA*(0 if is_terminal else V_curr)+reward-V_last
for param, trace in zip(network.parameters(), traces):
param.copy_(param+alpha*(trace*delta[0]+BETA*param.grad))

# Update uses RMSProp with initialization debiasing
for param, trace, MSG in zip(network.parameters(), traces, MSGs):
grad = trace*delta[0]+BETA*param.grad
MSG.copy_(GAMMA_RMS*MSG+(1-GAMMA_RMS)*grad*grad)
param.copy_(param+alpha*grad/(torch.sqrt(MSG/(1-GAMMA_RMS**(time_step+1))+EPS_RMS)))

# Always update trace
with torch.no_grad():
Expand Down Expand Up @@ -216,9 +221,15 @@ def AC_lambda(env, output_file_name, store_intermediate_result=False, load_path=
# Instantiate networks, optimizer, loss and buffer
network = ACNetwork(in_channels, num_actions).to(device)

# Eligibility traces are stored here
traces = [torch.zeros(x.size(), dtype=torch.float32, device=device) for x in network.parameters()]

# Space allocated to store gradients used in training
grads = [torch.zeros(x.size(), dtype=torch.float32, device=device) for x in network.parameters()]

# Running average of mean squared gradient for use in RMSProp
MSG = [torch.zeros(x.size(), dtype=torch.float32, device=device) for x in network.parameters()]

# Set initial values
e = 0
t = 0
Expand Down Expand Up @@ -252,14 +263,14 @@ def AC_lambda(env, output_file_name, store_intermediate_result=False, load_path=
is_terminated = False
s_last = None
r_last = None
t_last = None
term_last = None
while(not is_terminated) and t < NUM_FRAMES:
# Generate data
s_prime, action, reward, is_terminated = world_dynamics(s, env, network)

sample = transition(s, s_last, action, r_last, t_last)
sample = transition(s, s_last, action, r_last, term_last)

train(sample, traces, grads, network, alpha)
train(sample, traces, grads, MSG, network, alpha, t)

G += reward.item()

Expand All @@ -268,13 +279,15 @@ def AC_lambda(env, output_file_name, store_intermediate_result=False, load_path=
# Continue the process
s_last = s
r_last = reward
t_last = is_terminated
term_last = is_terminated
s = s_prime

# Increment the episodes
e += 1
sample = transition(s, s_last, action, r_last, t_last)
train(sample, traces, grads, network, alpha)
sample = transition(s, s_last, action, r_last, term_last)
train(sample, traces, grads, MSG, network, alpha, t)

# Clear elligibility traces after each episode
for trace in traces:
trace.zero_()

Expand All @@ -287,7 +300,8 @@ def AC_lambda(env, output_file_name, store_intermediate_result=False, load_path=
avg_return = 0.99 * avg_return + 0.01 * G
if e % 1000 == 0:
logging.info("Episode " + str(e) + " | Return: " + str(G) + " | Avg return: " +
str(numpy.around(avg_return, 2)) + " | Frame: " + str(t)+" | Time per frame: " +str((time.time()-t_start)/t) )
str(numpy.around(avg_return, 2)) + " | Frame: " + str(t)+" | Time per frame: " +
str((time.time()-t_start)/t) )

# Save model data and other intermediate data if specified
if store_intermediate_result and e % 1000 == 0:
Expand Down Expand Up @@ -320,14 +334,12 @@ def main():
parser.add_argument("--loadfile", "-l", type=str)
parser.add_argument("--alpha", "-a", type=float, default=ALPHA)
parser.add_argument("--save", "-s", action="store_true")
parser.add_argument("--replayoff", "-r", action="store_true")
parser.add_argument("--targetoff", "-t", action="store_true")
args = parser.parse_args()

if args.verbose:
logging.basicConfig(level=logging.INFO)

# If there's an output specified, then use the user specified output. Otherwise, create file in the current
# If there's an output specified, then use the user specified output. Otherwise, create file in the current
# directory with the game's name.
if args.output:
file_name = args.output
Expand Down
Binary file added img/learning_curves.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed img/results.gif
Binary file not shown.
Binary file added img/sensitivity_curves.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='MinAtar',
version='1.0.4',
version='1.0.5',
description='A miniaturized version of the arcade learning environment.',
url='https://github.com/kenjyoung/MinAtar',
author='Kenny Young',
Expand Down

0 comments on commit 87c2bc8

Please sign in to comment.