Add SPR implementation to atari_100k lab #184

Open · wants to merge 3 commits into master
Changes from 1 commit:
Working version
MaxASchwarzer committed Sep 24, 2021
commit 13b78754ce7627171d1d65aab10f426d870111cb
2 changes: 1 addition & 1 deletion dopamine/labs/atari_100k/configs/SPR.gin
@@ -48,4 +48,4 @@ Runner.max_steps_per_episode = 27000 # agent steps
DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.replay_capacity = 200000
DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.batch_size = 32
DeterministicOutOfGraphTemporalReplayBuffer.replay_capacity = 200000
DeterministicOutOfGraphTemporalReplayBuffer.batch_size = 32
DeterministicOutOfGraphTemporalReplayBuffer.batch_size = 32
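
For readers less familiar with gin, bindings like the ones above attach keyword-argument values to a configurable class by name, so the SPR agent's replay buffers pick them up at construction time. A minimal illustration of the mechanism; the `ToyReplayBuffer` class and config string are illustrative, not part of this PR:

```python
import gin


@gin.configurable
class ToyReplayBuffer:
  """Stand-in for the configurable replay buffer classes bound in SPR.gin."""

  def __init__(self, replay_capacity=1000, batch_size=8):
    self.replay_capacity = replay_capacity
    self.batch_size = batch_size


# Bindings are keyed on the class name, exactly as in the .gin file above.
gin.parse_config("""
ToyReplayBuffer.replay_capacity = 200000
ToyReplayBuffer.batch_size = 32
""")

buf = ToyReplayBuffer()  # picks up the gin-bound values instead of the defaults
assert (buf.replay_capacity, buf.batch_size) == (200000, 32)
```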
41 changes: 24 additions & 17 deletions dopamine/labs/atari_100k/spr_agent.py
@@ -49,6 +49,7 @@
import jax.numpy as jnp
import numpy as onp
import tensorflow as tf
+import optax
from dopamine.labs.atari_100k import spr_networks as networks
from dopamine.labs.atari_100k.replay_memory import time_batch_replay_buffer as tdrbs

@@ -94,14 +95,14 @@ def get_spr_targets(model, states, key):
return results


-@functools.partial(jax.jit, static_argnums=(0, 11, 12, 13, 15))
-def train(network_def, target_params, optimizer, states, actions, next_states,
+@functools.partial(jax.jit, static_argnums=(0, 3, 13, 14, 15, 17))
+def train(network_def, online_params, target_params, optimizer, optimizer_state, states, actions, next_states,
rewards, terminals, same_traj_mask, loss_weights, support,
cumulative_gamma, double_dqn, distributional, rng, spr_weight):
"""Run a training step."""

current_state = states[:, 0]
-online_params = optimizer.target
+online_params = online_params
# Split the current rng into 2 for updating the rng after this call
rng, rng1, rng2 = jax.random.split(rng, num=3)
use_spr = spr_weight > 0
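
The reworked `train` signature threads `online_params` and `optimizer_state` through explicitly and marks the optax optimizer itself (argument index 3) as static, which relies on an optax `GradientTransformation` being a hashable NamedTuple of functions, while the array-valued optimizer state stays a traced argument. A minimal sketch of that pattern; the names `step`, `loss_fn`, and the toy data are illustrative, not taken from this PR:

```python
import functools
import jax
import jax.numpy as jnp
import optax


@functools.partial(jax.jit, static_argnums=(0,))  # the GradientTransformation is a compile-time constant
def step(optimizer, params, opt_state, batch):
  loss_fn = lambda p: jnp.mean((batch @ p) ** 2)  # toy quadratic loss
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  return optax.apply_updates(params, updates), opt_state


optimizer = optax.adam(1e-4)
params = jnp.ones((4,))
opt_state = optimizer.init(params)
params, opt_state = step(optimizer, params, opt_state, jnp.eye(4))
```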
@@ -198,10 +199,11 @@ def q_online(state, key, actions=None, do_rollout=False):

# Get the unweighted loss without taking its mean for updating priorities.
(mean_loss, (loss, dqn_loss,
-spr_loss)), grad = grad_fn(optimizer.target, target, spr_targets,
+spr_loss)), grad = grad_fn(online_params, target, spr_targets,
loss_weights)
-optimizer = optimizer.apply_gradient(grad)
-return optimizer, loss, mean_loss, dqn_loss, spr_loss, rng2
+updates, optimizer_state = optimizer.update(grad, optimizer_state)
+online_params = optax.apply_updates(online_params, updates)
+return optimizer_state, online_params, loss, mean_loss, dqn_loss, spr_loss, rng2


@functools.partial(
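
This hunk is the core of the flax.optim to optax migration: instead of `optimizer.apply_gradient(grad)` mutating a wrapper that also holds the parameters, the gradient is converted into `updates` and applied to an explicit parameter pytree, and the optimizer state is threaded alongside it. A self-contained sketch of that mapping; the parameter tree and gradient here are stand-ins, not the PR's objects:

```python
import jax
import jax.numpy as jnp
import optax

# Stand-ins for the agent's pytrees: a tiny parameter tree and a fake gradient.
online_params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
grad = jax.tree_util.tree_map(lambda x: 0.1 * jnp.ones_like(x), online_params)

optimizer = optax.adam(1e-4)                     # plays the role of the agent's optimizer
optimizer_state = optimizer.init(online_params)

# Old flax.optim pattern (for contrast):
#   online_params = optimizer.target
#   optimizer = optimizer.apply_gradient(grad)
# New optax pattern, mirroring the hunk above:
updates, optimizer_state = optimizer.update(grad, optimizer_state)
online_params = optax.apply_updates(online_params, updates)
```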
@@ -333,15 +335,15 @@ def __init__(self,

def _build_networks_and_optimizer(self):
self._rng, rng = jax.random.split(self._rng)
-online_network_params = self.network_def.init(
+self.online_params = self.network_def.init(
rng,
x=self.state,
actions=jnp.zeros((5,)),
do_rollout=self.spr_weight > 0,
support=self._support)
-optimizer_def = dqn_agent.create_optimizer(self._optimizer_name)
-self.optimizer = optimizer_def.create(online_network_params)
-self.target_network_params = copy.deepcopy(online_network_params)
+self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
+self.optimizer_state = self.optimizer.init(self.online_params)
+self.target_network_params = copy.deepcopy(self.online_params)

def _build_replay_buffer(self):
"""Creates the replay buffer used by the agent."""
@@ -404,15 +406,20 @@ def _training_step_update(self):
# Uniform weights if not using prioritized replay.
loss_weights = jnp.ones(states.shape[0])

-self.optimizer, loss, mean_loss, dqn_loss, spr_loss, self._rng = train(
-    self.network_def, self.target_network_params, self.optimizer, states,
-    self.replay_elements['action'], next_states,
-    self.replay_elements['reward'][:, 0], self.replay_elements['terminal'][:, 0],
+self.optimizer_state, self.online_params, loss, mean_loss,\
+    dqn_loss, spr_loss,\
+    self._rng = train(
+        self.network_def, self.online_params, self.target_network_params,
+        self.optimizer, self.optimizer_state,
+        states,
+        self.replay_elements['action'],
+        next_states,
+        self.replay_elements['reward'][:, 0],
+        self.replay_elements['terminal'][:, 0],
self.replay_elements['same_trajectory'][:, 1:], loss_weights,
self._support, self.cumulative_gamma, self._double_dqn,
-self._distributional, self._rng, self.spr_weight)
+self._distributional, self._rng, self.spr_weight
+)

if self._replay_scheme == 'prioritized':
# Rainbow and prioritized replay are parametrized by an exponent
21 changes: 13 additions & 8 deletions dopamine/labs/atari_100k/train.py
@@ -25,14 +25,13 @@
from absl import logging
from dopamine.discrete_domains import run_experiment
from dopamine.discrete_domains import train as base_train
-from dopamine.labs.atari_100k import atari_100k_rainbow_agent
+from dopamine.labs.atari_100k import atari_100k_rainbow_agent, spr_agent
from dopamine.labs.atari_100k import eval_run_experiment
import numpy as np
import tensorflow as tf


FLAGS = flags.FLAGS
-AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps']
+AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps', 'SPR']

# flags are defined when importing run_xm_preprocessing
flags.DEFINE_enum('agent', 'DER', AGENTS, 'Name of the agent.')
@@ -41,11 +40,17 @@
'Whether to use `MaxEpisodeEvalRunner` or not.')


-def create_agent(sess,  # pylint: disable=unused-argument
-                 environment,
-                 seed,
-                 summary_writer=None):
+def create_agent(
+    sess,  # pylint: disable=unused-argument
+    environment,
+    seed,
+    agent,
+    summary_writer=None):
"""Helper function for creating full rainbow-based Atari 100k agent."""
if agent == "SPR":
return spr_agent.SPRAgent(num_actions=environment.action_space.n,
seed=seed,
summary_writer=summary_writer)
return atari_100k_rainbow_agent.Atari100kRainbowAgent(
num_actions=environment.action_space.n,
seed=seed,
@@ -72,7 +77,7 @@ def main(unused_argv):
gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings
run_experiment.load_gin_configs(gin_files, gin_bindings)
# Set the Jax agent seed using the run number
-create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
+create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number, agent=FLAGS.agent)
if FLAGS.max_episode_eval:
runner_fn = eval_run_experiment.MaxEpisodeEvalRunner
logging.info('Using MaxEpisodeEvalRunner for evaluation.')
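
Together with the extended `AGENTS` list and the `--agent` flag above, the extra `agent=FLAGS.agent` binding is what routes `SPR` to `spr_agent.SPRAgent`; the runner supplies the remaining arguments when it constructs the agent. A small sketch of that partial-binding flow; `DummyEnv` and the constant values are illustrative, not part of this PR:

```python
import functools


class DummyEnv:
  """Stand-in for the Atari environment the runner would pass in."""

  class action_space:
    n = 4


def create_agent(sess, environment, seed, agent, summary_writer=None):
  # Stand-in for train.create_agent: just report what would be built.
  return f'{agent} agent, {environment.action_space.n} actions, seed {seed}'


# Analogous to: functools.partial(create_agent, seed=FLAGS.run_number, agent=FLAGS.agent)
create_agent_fn = functools.partial(create_agent, seed=1, agent='SPR')

# The runner fills in sess/environment (and optionally summary_writer) later:
print(create_agent_fn(None, DummyEnv()))  # -> "SPR agent, 4 actions, seed 1"
```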