From 6dc94f7bd78718d77bc42ed1cea870f6ef4a0c32 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 14:02:59 -0500 Subject: [PATCH 01/29] dqn --- actors/dqn.py | 411 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 411 insertions(+) create mode 100644 actors/dqn.py diff --git a/actors/dqn.py b/actors/dqn.py new file mode 100644 index 00000000..630208e9 --- /dev/null +++ b/actors/dqn.py @@ -0,0 +1,411 @@ +# Copyright 2022 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=E0611 + +import logging +import copy +import time +import json +import math +import numpy as np + +import cogment +import torch +from gym.spaces import Discrete, utils + +from cogment_verse.specs import AgentConfig, cog_settings, EnvironmentConfig, EnvironmentSpecs + +from cogment_verse.constants import PLAYER_ACTOR_CLASS, WEB_ACTOR_NAME, HUMAN_ACTOR_IMPL + +from cogment_verse import Model, TorchReplayBuffer # pylint: disable=abstract-class-instantiated + +torch.multiprocessing.set_sharing_strategy("file_system") + +log = logging.getLogger(__name__) + + +def create_linear_schedule(start, end, duration): + slope = (end - start) / duration + + def compute_value(t): + return max(slope * t + start, end) + + return compute_value + + +class DQNModel(Model): + def __init__( + self, + model_id, + environment_implementation, + num_input, + num_output, + num_hidden_nodes, + epsilon, + dtype=torch.float, + version_number=0, + ): + super().__init__(model_id, version_number) + self._dtype = dtype + self._environment_implementation = environment_implementation + self._num_input = num_input + self._num_output = num_output + self._num_hidden_nodes = list(num_hidden_nodes) + + self.epsilon = epsilon + self.network = torch.nn.Sequential( + torch.nn.Linear(self._num_input, self._num_hidden_nodes[0]), + torch.nn.ReLU(), + *[ + layer + for hidden_node_idx in range(len(self._num_hidden_nodes) - 1) + for layer in [ + torch.nn.Linear(self._num_hidden_nodes[hidden_node_idx], self._num_hidden_nodes[-1]), + torch.nn.ReLU(), + ] + ], + torch.nn.Linear(self._num_hidden_nodes[-1], self._num_output), + ) + + # version user data + self.num_samples_seen = 0 + + def get_model_user_data(self): + return { + "environment_implementation": self._environment_implementation, + "num_input": self._num_input, + "num_output": self._num_output, + "num_hidden_nodes": json.dumps(self._num_hidden_nodes), + } + + def save(self, model_data_f): + torch.save((self.network.state_dict(), self.epsilon), model_data_f) + + return {"num_samples_seen": self.num_samples_seen} + + @classmethod + def load(cls, model_id, version_number, model_user_data, version_user_data, model_data_f): + # Create the model instance + model = DQNModel( + model_id=model_id, + version_number=version_number, + environment_implementation=model_user_data["environment_implementation"], + num_input=int(model_user_data["num_input"]), + num_output=int(model_user_data["num_output"]), + 
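+            # num_hidden_nodes was stored as a JSON string by get_model_user_data; decode it back to a list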
num_hidden_nodes=json.loads(model_user_data["num_hidden_nodes"]), + epsilon=0, + ) + + # Load the saved states + (network_state_dict, epsilon) = torch.load(model_data_f) + model.network.load_state_dict(network_state_dict) + model.epsilon = epsilon + + # Load version data + model.num_samples_seen = int(version_user_data["num_samples_seen"]) + + return model + + +class SimpleDQNActor: + def __init__(self, _cfg): + self._dtype = torch.float + self.samples_since_update = 0 + + def get_actor_classes(self): + return [PLAYER_ACTOR_CLASS] + + async def impl(self, actor_session): + actor_session.start() + + config = actor_session.config + + rng = np.random.default_rng(config.seed if config.seed is not None else 0) + + environment_specs = EnvironmentSpecs.deserialize(config.environment_specs) + observation_space = environment_specs.get_observation_space() + action_space = environment_specs.get_action_space(seed=rng.integers(9999)) + + assert isinstance(action_space.gym_space, Discrete) + + model, _, _ = await actor_session.model_registry.retrieve_version( + DQNModel, config.model_id, config.model_version + ) + model.network.eval() + + async for event in actor_session.all_events(): + if event.observation and event.type == cogment.EventType.ACTIVE: + observation = observation_space.deserialize(event.observation.observation) + if observation.current_player is not None and observation.current_player != actor_session.name: + # Not the turn of the agent + actor_session.do_action(action_space.serialize(action_space.create())) + continue + + if ( + config.model_update_frequency > 0 + and self.samples_since_update > 0 + and self.samples_since_update % config.model_update_frequency == 0 + ): + model, _, _ = await actor_session.model_registry.retrieve_version( + DQNModel, config.model_id, config.model_version + ) + model.network.eval() + self.samples_since_update = 0 + + if rng.random() < model.epsilon: + action = action_space.sample(mask=observation.action_mask) + else: + obs_tensor = torch.tensor(observation.flat_value, dtype=self._dtype) + action_probs = model.network(obs_tensor) + action_mask = observation.action_mask + if action_mask is not None: + action_mask_tensor = torch.tensor(action_mask, dtype=self._dtype) + large = torch.finfo(self._dtype).max + if torch.equal(action_mask_tensor, torch.zeros_like(action_mask_tensor)): + log.info("no moves are available, this shouldn't be possible") + action_probs = action_probs - large * (1 - action_mask_tensor) + discrete_action_tensor = torch.argmax(action_probs) + action = action_space.create(value=discrete_action_tensor.item()) + + self.samples_since_update += 1 + actor_session.do_action(action_space.serialize(action)) + + +class SimpleDQNTraining: + default_cfg = { + "seed": 10, + "num_trials": 5000, + "num_parallel_trials": 10, + "learning_rate": 0.00025, + "buffer_size": 10000, + "discount_factor": 0.99, + "target_update_frequency": 500, + "batch_size": 128, + "epsilon_schedule_start": 1, + "epsilon_schedule_end": 0.05, + "epsilon_schedule_duration_ratio": 0.75, + "learning_starts": 10000, + "train_frequency": 10, + "model_update_frequency": 10, + "value_network": {"num_hidden_nodes": [128, 64]}, + } + + def __init__(self, environment_specs, cfg): + super().__init__() + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._dtype = torch.float + self._environment_specs = environment_specs + self._cfg = cfg + + async def sample_producer_impl(self, sample_producer_session): + player_actor_params = 
sample_producer_session.trial_info.parameters.actors[0] + + player_actor_name = player_actor_params.name + player_environment_specs = EnvironmentSpecs.deserialize(player_actor_params.config.environment_specs) + player_observation_space = player_environment_specs.get_observation_space() + player_action_space = player_environment_specs.get_action_space() + + observation = None + action = None + reward = None + + total_reward = 0 + + async for sample in sample_producer_session.all_trial_samples(): + actor_sample = sample.actors_data[player_actor_name] + if actor_sample.observation is None: + # This can happen when there is several "end-of-trial" samples + continue + + next_observation = torch.tensor( + player_observation_space.deserialize(actor_sample.observation).flat_value, dtype=self._dtype + ) + + if observation is not None: + # It's not the first sample, let's check if it is the last + done = sample.trial_state == cogment.TrialState.ENDED + sample_producer_session.produce_sample( + ( + observation, + next_observation, + action, + reward, + torch.ones(1, dtype=torch.int8) if done else torch.zeros(1, dtype=torch.int8), + total_reward, + ) + ) + if done: + break + + observation = next_observation + action_value = player_action_space.deserialize(actor_sample.action).value + action = torch.tensor(action_value, dtype=torch.int64) + reward = torch.tensor(actor_sample.reward if actor_sample.reward is not None else 0, dtype=self._dtype) + total_reward += reward.item() + + async def impl(self, run_session): + # Initializing a model + model_id = f"{run_session.run_id}_model" + + assert self._environment_specs.num_players == 1 + action_space = self._environment_specs.get_action_space() + observation_space = self._environment_specs.get_observation_space() + assert isinstance(action_space.gym_space, Discrete) + + epsilon_schedule = create_linear_schedule( + self._cfg.epsilon_schedule_start, + self._cfg.epsilon_schedule_end, + self._cfg.epsilon_schedule_duration_ratio * self._cfg.num_trials, + ) + + model = DQNModel( + model_id, + environment_implementation=self._environment_specs.implementation, + num_input=utils.flatdim(observation_space.gym_space), + num_output=utils.flatdim(action_space.gym_space), + num_hidden_nodes=self._cfg.value_network.num_hidden_nodes, + epsilon=epsilon_schedule(0), + dtype=self._dtype, + ) + _model_info, version_info = await run_session.model_registry.publish_initial_version(model) + + run_session.log_params( + self._cfg, + model_id=model_id, + environment_implementation=self._environment_specs.implementation, + ) + + # Configure the optimizer + optimizer = torch.optim.Adam( + model.network.parameters(), + lr=self._cfg.learning_rate, + ) + + # Initialize the target model + target_network = copy.deepcopy(model.network) + + replay_buffer = TorchReplayBuffer( + capacity=self._cfg.buffer_size, + observation_shape=(utils.flatdim(observation_space.gym_space),), + observation_dtype=self._dtype, + action_shape=(1,), + action_dtype=torch.int64, + reward_dtype=self._dtype, + seed=self._cfg.seed, + ) + + start_time = time.time() + total_reward_cum = 0 + + for (step_idx, _trial_id, trial_idx, sample,) in run_session.start_and_await_trials( + trials_id_and_params=[ + ( + f"{run_session.run_id}_{trial_idx}", + cogment.TrialParameters( + cog_settings, + environment_name="env", + environment_implementation=self._environment_specs.implementation, + environment_config=EnvironmentConfig( + run_id=run_session.run_id, + render=False, + seed=self._cfg.seed + trial_idx, + ), + actors=[ + 
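+                        # One DQN player per trial; AgentConfig tells it which model to pull from the registry and how often to refresh it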
cogment.ActorParameters( + cog_settings, + name="player", + class_name=PLAYER_ACTOR_CLASS, + implementation="actors.simple_dqn.SimpleDQNActor", + config=AgentConfig( + run_id=run_session.run_id, + seed=self._cfg.seed + trial_idx, + model_id=model_id, + model_version=-1, + model_update_frequency=self._cfg.model_update_frequency, + environment_specs=self._environment_specs.serialize(), + ), + ) + ], + ), + ) + for trial_idx in range(self._cfg.num_trials) + ], + sample_producer_impl=self.sample_producer_impl, + num_parallel_trials=self._cfg.num_parallel_trials, + ): + (observation, next_observation, action, reward, done, total_reward) = sample + replay_buffer.add( + observation=observation, next_observation=next_observation, action=action, reward=reward, done=done + ) + + trial_done = done.item() == 1 + + if trial_done: + run_session.log_metrics(trial_idx=trial_idx, total_reward=total_reward) + total_reward_cum += total_reward + if (trial_idx + 1) % 100 == 0: + total_reward_avg = total_reward_cum / 100 + run_session.log_metrics(total_reward_avg=total_reward_avg) + total_reward_cum = 0 + log.info( + f"[SimpleDQN/{run_session.run_id}] trial #{trial_idx + 1}/{self._cfg.num_trials} done (average total reward = {total_reward_avg})." + ) + + if ( + step_idx > self._cfg.learning_starts + and replay_buffer.size() > self._cfg.batch_size + and step_idx % self._cfg.train_frequency == 0 + ): + data = replay_buffer.sample(self._cfg.batch_size) + + with torch.no_grad(): + target_values, _ = target_network(data.next_observation).max(dim=1) + td_target = data.reward.flatten() + self._cfg.discount_factor * target_values * ( + 1 - data.done.flatten() + ) + + action_values = model.network(data.observation).gather(1, data.action).squeeze() + loss = torch.nn.functional.mse_loss(td_target, action_values) + + # optimize the model + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Update the epsilon + model.epsilon = epsilon_schedule(trial_idx) + + # Update the version info + model.num_samples_seen += data.size() + + if step_idx % self._cfg.target_update_frequency == 0: + target_network.load_state_dict(model.network.state_dict()) + + version_info = await run_session.model_registry.publish_version(model) + + if step_idx % 100 == 0: + end_time = time.time() + steps_per_seconds = 100 / (end_time - start_time) + start_time = end_time + run_session.log_metrics( + model_version_number=version_info["version_number"], + loss=loss.item(), + q_values=action_values.mean().item(), + batch_avg_reward=data.reward.mean().item(), + epsilon=model.epsilon, + steps_per_seconds=steps_per_seconds, + ) + + version_info = await run_session.model_registry.publish_version(model, archived=True) From 225d1d6c4eab6bc732bce494784c4a593ccbe80f Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 14:03:32 -0500 Subject: [PATCH 02/29] test dqn configs --- config/experiment/simple_dqn/cartpole.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/experiment/simple_dqn/cartpole.yaml b/config/experiment/simple_dqn/cartpole.yaml index fb25fd60..2f36893f 100644 --- a/config/experiment/simple_dqn/cartpole.yaml +++ b/config/experiment/simple_dqn/cartpole.yaml @@ -4,7 +4,7 @@ defaults: - simple_dqn - override /services/environment: cartpole run: - class_name: actors.simple_dqn.SimpleDQNTraining + class_name: actors.dqn.SimpleDQNTraining seed: 618 # Archiving @@ -29,5 +29,5 @@ run: epsilon_schedule_start: 1 epsilon_schedule_end: 0.05 epsilon_schedule_duration_ratio: 0.75 - - + + From 
33aaa942e09681828408736f47eea6d8f46d314a Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 14:26:40 -0500 Subject: [PATCH 03/29] add MLP and conv network classes --- actors/dqn.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/actors/dqn.py b/actors/dqn.py index 630208e9..bdda55c9 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -21,8 +21,10 @@ import math import numpy as np +from typing import List, Tuple, Union import cogment import torch +from torch import nn from gym.spaces import Discrete, utils from cogment_verse.specs import AgentConfig, cog_settings, EnvironmentConfig, EnvironmentSpecs @@ -44,6 +46,104 @@ def compute_value(t): return compute_value +# Acknowledgements: The networks and associated utils are adapted from RLHive + +def calculate_output_dim(net, input_shape): + if isinstance(input_shape, int): + input_shape = (input_shape,) + placeholder = torch.zeros((0,) + tuple(input_shape)) + output = net(placeholder) + return output.size()[1:] + +class MLPNetwork(nn.Module): + def __init__( + self, + in_dim: Tuple[int], + hidden_units: Union[int, List[int]] = 256, + noisy: bool = False, + std_init: float = 0.5, + ): + super().__init__() + if isinstance(hidden_units, int): + hidden_units = [hidden_units] + modules = [nn.Linear(np.prod(in_dim), hidden_units[0]), torch.nn.ReLU()] + for i in range(len(hidden_units) - 1): + modules.append(nn.Linear(hidden_units[i], hidden_units[i + 1])) + modules.append(torch.nn.ReLU()) + self.network = torch.nn.Sequential(*modules) + + def forward(self, x): + x = x.float() + x = torch.flatten(x, start_dim=1) + return self.network(x) + + +class ConvNetwork(nn.Module): + def __init__( + self, + in_dim, + channels=None, + mlp_layers=None, + kernel_sizes=1, + strides=1, + paddings=0, + normalization_factor=255, + noisy=False, + std_init=0.5, + ): + super().__init__() + self._normalization_factor = normalization_factor + if channels is not None: + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * len(channels) + if isinstance(strides, int): + strides = [strides] * len(channels) + if isinstance(paddings, int): + paddings = [paddings] * len(channels) + + if not all( + len(x) == len(channels) for x in [kernel_sizes, strides, paddings] + ): + raise ValueError("The lengths of the parameter lists must be the same") + + # Convolutional Layers + channels.insert(0, in_dim[0]) + conv_seq = [] + for i in range(0, len(channels) - 1): + conv_seq.append( + torch.nn.Conv2d( + in_channels=channels[i], + out_channels=channels[i + 1], + kernel_size=kernel_sizes[i], + stride=strides[i], + padding=paddings[i], + ) + ) + conv_seq.append(torch.nn.ReLU()) + self.conv = torch.nn.Sequential(*conv_seq) + else: + self.conv = torch.nn.Identity() + + if mlp_layers is not None: + # MLP Layers + conv_output_size = calculate_output_dim(self.conv, in_dim) + self.mlp = MLPNetwork( + conv_output_size, mlp_layers, noisy=noisy, std_init=std_init + ) + else: + self.mlp = torch.nn.Identity() + + def forward(self, x): + if len(x.shape) == 3: + x = x.unsqueeze(0) + elif len(x.shape) == 5: + x = x.reshape(x.size(0), -1, x.size(-2), x.size(-1)) + x = x.float() + x = x / self._normalization_factor + x = self.conv(x) + x = self.mlp(x) + return x + class DQNModel(Model): def __init__( From e0409b8763d0bee08c0cb9f3f776b3227a283a09 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 14:49:42 -0500 Subject: [PATCH 04/29] DQN network with conv and MLP --- actors/dqn.py | 40 
+++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index bdda55c9..2cb9d961 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -145,6 +145,23 @@ def forward(self, x): return x +class DQNNetwork(nn.Module): + def __init__( + self, + base_network: nn.Module, + hidden_dim: int, + out_dim: int, + ): + super().__init__() + self.base_network = base_network + self.output_layer = nn.Linear(hidden_dim, out_dim) + + def forward(self, x): + x = self.base_network(x) + x = x.flatten(start_dim=1) + return self.output_layer(x) + + class DQNModel(Model): def __init__( self, @@ -165,19 +182,8 @@ def __init__( self._num_hidden_nodes = list(num_hidden_nodes) self.epsilon = epsilon - self.network = torch.nn.Sequential( - torch.nn.Linear(self._num_input, self._num_hidden_nodes[0]), - torch.nn.ReLU(), - *[ - layer - for hidden_node_idx in range(len(self._num_hidden_nodes) - 1) - for layer in [ - torch.nn.Linear(self._num_hidden_nodes[hidden_node_idx], self._num_hidden_nodes[-1]), - torch.nn.ReLU(), - ] - ], - torch.nn.Linear(self._num_hidden_nodes[-1], self._num_output), - ) + self.base_network = MLPNetwork(num_input, self._num_hidden_nodes) + self.network = DQNNetwork(self.base_network, self._num_hidden_nodes[-1], self._num_output) # version user data self.num_samples_seen = 0 @@ -219,7 +225,7 @@ def load(cls, model_id, version_number, model_user_data, version_user_data, mode return model -class SimpleDQNActor: +class DQNActor: def __init__(self, _cfg): self._dtype = torch.float self.samples_since_update = 0 @@ -283,7 +289,7 @@ async def impl(self, actor_session): actor_session.do_action(action_space.serialize(action)) -class SimpleDQNTraining: +class DQNTraining: default_cfg = { "seed": 10, "num_trials": 5000, @@ -427,7 +433,7 @@ async def impl(self, run_session): cog_settings, name="player", class_name=PLAYER_ACTOR_CLASS, - implementation="actors.simple_dqn.SimpleDQNActor", + implementation="actors.dqn.DQNActor", config=AgentConfig( run_id=run_session.run_id, seed=self._cfg.seed + trial_idx, @@ -460,7 +466,7 @@ async def impl(self, run_session): run_session.log_metrics(total_reward_avg=total_reward_avg) total_reward_cum = 0 log.info( - f"[SimpleDQN/{run_session.run_id}] trial #{trial_idx + 1}/{self._cfg.num_trials} done (average total reward = {total_reward_avg})." + f"[DQN/{run_session.run_id}] trial #{trial_idx + 1}/{self._cfg.num_trials} done (average total reward = {total_reward_avg})." ) if ( From ffa6af155e1f03504dd498773d5a6e5b740c7d50 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 15:00:29 -0500 Subject: [PATCH 05/29] registering new DQN actor. 
And config --- config/experiment/simple_dqn/cartpole.yaml | 4 ++-- config/services/actor/dqn.yaml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 config/services/actor/dqn.yaml diff --git a/config/experiment/simple_dqn/cartpole.yaml b/config/experiment/simple_dqn/cartpole.yaml index 2f36893f..9a2a0b8e 100644 --- a/config/experiment/simple_dqn/cartpole.yaml +++ b/config/experiment/simple_dqn/cartpole.yaml @@ -1,10 +1,10 @@ # @package _global_ defaults: - override /services/actor: - - simple_dqn + - dqn - override /services/environment: cartpole run: - class_name: actors.dqn.SimpleDQNTraining + class_name: actors.dqn.DQNTraining seed: 618 # Archiving diff --git a/config/services/actor/dqn.yaml b/config/services/actor/dqn.yaml new file mode 100644 index 00000000..cd680fdf --- /dev/null +++ b/config/services/actor/dqn.yaml @@ -0,0 +1,3 @@ +simple_dqn: + class_name: actors.dqn.DQNActor + port: ${generate_port:actors.dqn.DQNActor} From 0c630c4f18e523d89558e24b04d319158e344868 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 9 Mar 2023 15:26:12 -0500 Subject: [PATCH 06/29] fix flatten in MLP and DQN networks --- actors/dqn.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 2cb9d961..0ae36eda 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -74,7 +74,9 @@ def __init__( def forward(self, x): x = x.float() - x = torch.flatten(x, start_dim=1) + # print("x shape = ", x.shape) + if len(x.shape) > 1: + x = torch.flatten(x, start_dim=1) return self.network(x) @@ -158,7 +160,9 @@ def __init__( def forward(self, x): x = self.base_network(x) - x = x.flatten(start_dim=1) + # print("in DQN x shape = ", x.shape) + if len(x.shape) > 1: + x = x.flatten(start_dim=1) return self.output_layer(x) @@ -182,8 +186,13 @@ def __init__( self._num_hidden_nodes = list(num_hidden_nodes) self.epsilon = epsilon + # print("num_input = ", num_input) + # print("num_hidden_nodes = ", self._num_hidden_nodes) + # print("num_output = ", num_output) self.base_network = MLPNetwork(num_input, self._num_hidden_nodes) + # print("self.base_network = ", self.base_network) self.network = DQNNetwork(self.base_network, self._num_hidden_nodes[-1], self._num_output) + # print("self.network = ", self.network) # version user data self.num_samples_seen = 0 From d98c7145169068148d85942613fc2876fc19c471 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 20 Mar 2023 12:15:39 -0400 Subject: [PATCH 07/29] torch and np seed --- actors/dqn.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 0ae36eda..b236a200 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -33,6 +33,8 @@ from cogment_verse import Model, TorchReplayBuffer # pylint: disable=abstract-class-instantiated +torch.manual_seed(0) +np.random.seed(0) torch.multiprocessing.set_sharing_strategy("file_system") log = logging.getLogger(__name__) @@ -74,7 +76,6 @@ def __init__( def forward(self, x): x = x.float() - # print("x shape = ", x.shape) if len(x.shape) > 1: x = torch.flatten(x, start_dim=1) return self.network(x) @@ -160,7 +161,6 @@ def __init__( def forward(self, x): x = self.base_network(x) - # print("in DQN x shape = ", x.shape) if len(x.shape) > 1: x = x.flatten(start_dim=1) return self.output_layer(x) @@ -186,13 +186,8 @@ def __init__( self._num_hidden_nodes = list(num_hidden_nodes) self.epsilon = epsilon - # print("num_input = ", num_input) - # print("num_hidden_nodes = ", self._num_hidden_nodes) - # 
print("num_output = ", num_output) self.base_network = MLPNetwork(num_input, self._num_hidden_nodes) - # print("self.base_network = ", self.base_network) self.network = DQNNetwork(self.base_network, self._num_hidden_nodes[-1], self._num_output) - # print("self.network = ", self.network) # version user data self.num_samples_seen = 0 From bd5fb003533a0e1347fac48bf0fce0eba8f36498 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Sun, 30 Apr 2023 21:13:52 -0400 Subject: [PATCH 08/29] pylint fixes --- actors/dqn.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index b236a200..defec870 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -18,10 +18,9 @@ import copy import time import json -import math +from typing import List, Tuple, Union import numpy as np -from typing import List, Tuple, Union import cogment import torch from torch import nn @@ -29,7 +28,7 @@ from cogment_verse.specs import AgentConfig, cog_settings, EnvironmentConfig, EnvironmentSpecs -from cogment_verse.constants import PLAYER_ACTOR_CLASS, WEB_ACTOR_NAME, HUMAN_ACTOR_IMPL +from cogment_verse.constants import PLAYER_ACTOR_CLASS from cogment_verse import Model, TorchReplayBuffer # pylint: disable=abstract-class-instantiated @@ -62,8 +61,6 @@ def __init__( self, in_dim: Tuple[int], hidden_units: Union[int, List[int]] = 256, - noisy: bool = False, - std_init: float = 0.5, ): super().__init__() if isinstance(hidden_units, int): From 4b65d123ef9af2f1feb01872d1143a99c7e204da Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Sun, 30 Apr 2023 21:14:18 -0400 Subject: [PATCH 09/29] black format --- actors/dqn.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index defec870..2b40b1d8 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -47,8 +47,10 @@ def compute_value(t): return compute_value + # Acknowledgements: The networks and associated utils are adapted from RLHive + def calculate_output_dim(net, input_shape): if isinstance(input_shape, int): input_shape = (input_shape,) @@ -56,6 +58,7 @@ def calculate_output_dim(net, input_shape): output = net(placeholder) return output.size()[1:] + class MLPNetwork(nn.Module): def __init__( self, @@ -101,9 +104,7 @@ def __init__( if isinstance(paddings, int): paddings = [paddings] * len(channels) - if not all( - len(x) == len(channels) for x in [kernel_sizes, strides, paddings] - ): + if not all(len(x) == len(channels) for x in [kernel_sizes, strides, paddings]): raise ValueError("The lengths of the parameter lists must be the same") # Convolutional Layers @@ -127,9 +128,7 @@ def __init__( if mlp_layers is not None: # MLP Layers conv_output_size = calculate_output_dim(self.conv, in_dim) - self.mlp = MLPNetwork( - conv_output_size, mlp_layers, noisy=noisy, std_init=std_init - ) + self.mlp = MLPNetwork(conv_output_size, mlp_layers, noisy=noisy, std_init=std_init) else: self.mlp = torch.nn.Identity() From a00128c9c3c100978cfb972dde70311854d2a27e Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Sun, 30 Apr 2023 23:20:40 -0400 Subject: [PATCH 10/29] fix args --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 2b40b1d8..ecdafc40 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -128,7 +128,7 @@ def __init__( if mlp_layers is not None: # MLP Layers conv_output_size = calculate_output_dim(self.conv, in_dim) - self.mlp = MLPNetwork(conv_output_size, mlp_layers, noisy=noisy, std_init=std_init) + self.mlp = 
MLPNetwork(conv_output_size, mlp_layers) else: self.mlp = torch.nn.Identity() From 853f90a39e11902cd80f0ba0eb65bb7506c3005d Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 1 May 2023 09:02:32 -0400 Subject: [PATCH 11/29] have both simple_dqn and dqn configs --- config/experiment/dqn/cartpole.yaml | 32 ++++++++++++++++++++++ config/experiment/simple_dqn/cartpole.yaml | 4 +-- 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 config/experiment/dqn/cartpole.yaml diff --git a/config/experiment/dqn/cartpole.yaml b/config/experiment/dqn/cartpole.yaml new file mode 100644 index 00000000..4edda77a --- /dev/null +++ b/config/experiment/dqn/cartpole.yaml @@ -0,0 +1,32 @@ +# @package _global_ +defaults: + - override /services/actor: + - dqn + - override /services/environment: cartpole + +run: + class_name: actors.dqn.DQNTraining + seed: 618 + + # Archiving + archive_model: True + archive_frequency: 20000 # Unit: steps + + # Training Params + num_trials: 10000 # Unit: trials + num_parallel_trials: 10 # Unit: trials + learning_starts: ${run.batch_size} # Unit: steps + target_update_frequency: 2000 # Unit: steps + buffer_size: 10000 # Unit: steps + train_frequency: 10 # Unit: steps + model_update_frequency: 10 # Unit: steps + + # Network Params + value_network: + num_hidden_nodes: [128, 64] + learning_rate: 0.000125 + discount_factor: 0.95 + batch_size: 64 + epsilon_schedule_start: 1 + epsilon_schedule_end: 0.05 + epsilon_schedule_duration_ratio: 0.75 diff --git a/config/experiment/simple_dqn/cartpole.yaml b/config/experiment/simple_dqn/cartpole.yaml index 4edda77a..e019081c 100644 --- a/config/experiment/simple_dqn/cartpole.yaml +++ b/config/experiment/simple_dqn/cartpole.yaml @@ -1,11 +1,11 @@ # @package _global_ defaults: - override /services/actor: - - dqn + - simple_dqn - override /services/environment: cartpole run: - class_name: actors.dqn.DQNTraining + class_name: actors.simple_dqn.DQNTraining seed: 618 # Archiving From af6472365c650cc13aceaa10ae704807b4493057 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 1 May 2023 20:53:46 -0400 Subject: [PATCH 12/29] serialize, deserialize --- actors/dqn.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index ecdafc40..44fb2f82 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -201,26 +201,35 @@ def save(self, model_data_f): return {"num_samples_seen": self.num_samples_seen} + @staticmethod + def serialize_model(model) -> bytes: + stream = io.BytesIO() + torch.save( + ( + model.network.state_dict(), + model.epsilon, + model.get_model_user_data(), + ), + stream, + ) + return stream.getvalue() + @classmethod - def load(cls, model_id, version_number, model_user_data, version_user_data, model_data_f): - # Create the model instance - model = DQNModel( - model_id=model_id, - version_number=version_number, + def deserialize_model(cls, serialized_model) -> DQNModel: + stream = io.BytesIO(serialized_model) + (network_state_dict, epsilon, model_user_data) = torch.load(stream) + + model = cls( + model_id=model_user_data["model_id"], environment_implementation=model_user_data["environment_implementation"], num_input=int(model_user_data["num_input"]), num_output=int(model_user_data["num_output"]), num_hidden_nodes=json.loads(model_user_data["num_hidden_nodes"]), epsilon=0, ) - - # Load the saved states - (network_state_dict, epsilon) = torch.load(model_data_f) model.network.load_state_dict(network_state_dict) model.epsilon = epsilon - - # Load 
version data - model.num_samples_seen = int(version_user_data["num_samples_seen"]) + model.num_samples_seen = int(model_user_data["num_samples_seen"]) return model From c84d7618cbd8cfb72a6b1b900a4d4a1537dfb19f Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 1 May 2023 20:56:14 -0400 Subject: [PATCH 13/29] pylint fixes --- actors/dqn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 44fb2f82..17150901 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -91,8 +91,6 @@ def __init__( strides=1, paddings=0, normalization_factor=255, - noisy=False, - std_init=0.5, ): super().__init__() self._normalization_factor = normalization_factor From 7deb328e9c8046c376780d1b249831c7667fbd13 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Fri, 5 May 2023 23:30:15 -0400 Subject: [PATCH 14/29] Update actors/dqn.py Co-authored-by: William Duguay <110052871+wduguay-air@users.noreply.github.com> --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 17150901..bb5e3078 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -272,7 +272,7 @@ async def impl(self, actor_session): and self.samples_since_update % config.model_update_frequency == 0 ): model, _, _ = await actor_session.model_registry.retrieve_version( - DQNModel, config.model_id, config.model_version + DQNModel, config.model_id, config.model_iteration ) model.network.eval() self.samples_since_update = 0 From 3589729e8b7334a53117e93e0aa6d958388f6810 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Fri, 5 May 2023 23:37:48 -0400 Subject: [PATCH 15/29] using new model registry --- actors/dqn.py | 27 ++++++++++++++-------- config/experiment/simple_dqn/cartpole.yaml | 4 ---- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 17150901..21c8620a 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -192,13 +192,9 @@ def get_model_user_data(self): "num_input": self._num_input, "num_output": self._num_output, "num_hidden_nodes": json.dumps(self._num_hidden_nodes), + "num_samples_seen": self.num_samples_seen, } - def save(self, model_data_f): - torch.save((self.network.state_dict(), self.epsilon), model_data_f) - - return {"num_samples_seen": self.num_samples_seen} - @staticmethod def serialize_model(model) -> bytes: stream = io.BytesIO() @@ -254,7 +250,7 @@ async def impl(self, actor_session): assert isinstance(action_space.gym_space, Discrete) model, _, _ = await actor_session.model_registry.retrieve_version( - DQNModel, config.model_id, config.model_version + DQNModel, config.model_id, config.model_iteration ) model.network.eval() @@ -392,7 +388,11 @@ async def impl(self, run_session): epsilon=epsilon_schedule(0), dtype=self._dtype, ) - _model_info, version_info = await run_session.model_registry.publish_initial_version(model) + serialized_model = DQNModel.serialize_model(model) + iteration_info = await run_session.model_registry.publish_model( + name=model_id, + model=serialized_model, + ) run_session.log_params( self._cfg, @@ -506,7 +506,11 @@ async def impl(self, run_session): if step_idx % self._cfg.target_update_frequency == 0: target_network.load_state_dict(model.network.state_dict()) - version_info = await run_session.model_registry.publish_version(model) + serialized_model = DQNModel.serialize_model(model) + iteration_info = await run_session.model_registry.publish_model( + name=model_id, + model=serialized_model, + ) if step_idx % 100 == 0: end_time = time.time() @@ -521,4 +525,9 @@ 
async def impl(self, run_session): steps_per_seconds=steps_per_seconds, ) - version_info = await run_session.model_registry.publish_version(model, archived=True) + serialized_model = DQNModel.serialize_model(model) + iteration_info = await run_session.model_registry.store_model( + name=model_id, + model=serialized_model, + ) + diff --git a/config/experiment/simple_dqn/cartpole.yaml b/config/experiment/simple_dqn/cartpole.yaml index e019081c..a86bb27a 100644 --- a/config/experiment/simple_dqn/cartpole.yaml +++ b/config/experiment/simple_dqn/cartpole.yaml @@ -8,10 +8,6 @@ run: class_name: actors.simple_dqn.DQNTraining seed: 618 - # Archiving - archive_model: True - archive_frequency: 20000 # Unit: steps - # Training Params num_trials: 10000 # Unit: trials num_parallel_trials: 10 # Unit: trials From f8b781be1c469f9d1449653e77e4723f50217b71 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Fri, 5 May 2023 23:42:19 -0400 Subject: [PATCH 16/29] black format --- actors/dqn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 5aea2ffe..89b51195 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -530,4 +530,3 @@ async def impl(self, run_session): name=model_id, model=serialized_model, ) - From f7b1556cea7596f3b8bfb386c80ac30551628e52 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Sat, 6 May 2023 00:09:22 -0400 Subject: [PATCH 17/29] pylint fixes --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 89b51195..957b5742 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -517,7 +517,7 @@ async def impl(self, run_session): steps_per_seconds = 100 / (end_time - start_time) start_time = end_time run_session.log_metrics( - model_version_number=version_info["version_number"], + model_iteration_number=iteration_info["iteration_number"], loss=loss.item(), q_values=action_values.mean().item(), batch_avg_reward=data.reward.mean().item(), From 39179d3129bd721c910333cda013f252c73ca98b Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Sat, 6 May 2023 00:10:52 -0400 Subject: [PATCH 18/29] io --- actors/dqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 957b5742..75c636f1 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -14,10 +14,11 @@ # pylint: disable=E0611 -import logging import copy +import io import time import json +import logging from typing import List, Tuple, Union import numpy as np From 7f00367ded7dbe19e5ce897f3b860b339f1f46f5 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 8 May 2023 09:51:37 -0400 Subject: [PATCH 19/29] import annotations --- actors/dqn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/actors/dqn.py b/actors/dqn.py index 75c636f1..62b3eb9b 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -20,6 +20,7 @@ import json import logging from typing import List, Tuple, Union +from __future__ import annotations import numpy as np import cogment From 2e1feb4cd53daecc2ade060dbf741e0a4b65ef88 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 8 May 2023 10:05:38 -0400 Subject: [PATCH 20/29] minor fix --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 62b3eb9b..57ba1f66 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -14,13 +14,13 @@ # pylint: disable=E0611 +from __future__ import annotations import copy import io import time import json import logging from typing import List, Tuple, Union -from __future__ import annotations 
import numpy as np import cogment From 82d5e0b2594a386c148bfc3243e531e616e2ee1e Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 8 May 2023 10:08:02 -0400 Subject: [PATCH 21/29] remove archiving class --- config/experiment/dqn/cartpole.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/config/experiment/dqn/cartpole.yaml b/config/experiment/dqn/cartpole.yaml index 4edda77a..e2f811e2 100644 --- a/config/experiment/dqn/cartpole.yaml +++ b/config/experiment/dqn/cartpole.yaml @@ -8,10 +8,6 @@ run: class_name: actors.dqn.DQNTraining seed: 618 - # Archiving - archive_model: True - archive_frequency: 20000 # Unit: steps - # Training Params num_trials: 10000 # Unit: trials num_parallel_trials: 10 # Unit: trials From 5bf676fd6cdf383950afa2f72044255279f68afe Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 8 May 2023 15:57:06 -0400 Subject: [PATCH 22/29] Update config/services/actor/dqn.yaml Co-authored-by: William Duguay <110052871+wduguay-air@users.noreply.github.com> --- config/services/actor/dqn.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/services/actor/dqn.yaml b/config/services/actor/dqn.yaml index cd680fdf..ffc687d0 100644 --- a/config/services/actor/dqn.yaml +++ b/config/services/actor/dqn.yaml @@ -1,3 +1,3 @@ -simple_dqn: +dqn: class_name: actors.dqn.DQNActor port: ${generate_port:actors.dqn.DQNActor} From c77d46129198a1dc9e6f528fd2f2b4a775f03381 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Tue, 9 May 2023 09:57:37 -0400 Subject: [PATCH 23/29] Update actors/dqn.py Co-authored-by: William Duguay <110052871+wduguay-air@users.noreply.github.com> --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 57ba1f66..f3dcbf6a 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -172,7 +172,7 @@ def __init__( num_hidden_nodes, epsilon, dtype=torch.float, - version_number=0, + iteration=0, ): super().__init__(model_id, version_number) self._dtype = dtype From 568ac2b63732a73114776daae8a779f342dfecd3 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Tue, 9 May 2023 09:57:46 -0400 Subject: [PATCH 24/29] Update actors/dqn.py Co-authored-by: William Duguay <110052871+wduguay-air@users.noreply.github.com> --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index f3dcbf6a..4e172318 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -174,7 +174,7 @@ def __init__( dtype=torch.float, iteration=0, ): - super().__init__(model_id, version_number) + super().__init__(model_id, iteration) self._dtype = dtype self._environment_implementation = environment_implementation self._num_input = num_input From 06857411489e995858dd838d44f24499475a2666 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Tue, 9 May 2023 09:57:55 -0400 Subject: [PATCH 25/29] Update actors/dqn.py Co-authored-by: William Duguay <110052871+wduguay-air@users.noreply.github.com> --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index 4e172318..e7902ddf 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -447,7 +447,7 @@ async def impl(self, run_session): run_id=run_session.run_id, seed=self._cfg.seed + trial_idx, model_id=model_id, - model_version=-1, + model_iteration=-1, model_update_frequency=self._cfg.model_update_frequency, environment_specs=self._environment_specs.serialize(), ), From 73fde922249c0b5f5e5edbe62680ac9edd12468a Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Tue, 9 
May 2023 10:02:31 -0400 Subject: [PATCH 26/29] current --- actors/dqn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 57ba1f66..8ad743bf 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -190,6 +190,7 @@ def __init__( def get_model_user_data(self): return { + "model_id": self.model_id, "environment_implementation": self._environment_implementation, "num_input": self._num_input, "num_output": self._num_output, @@ -251,8 +252,8 @@ async def impl(self, actor_session): assert isinstance(action_space.gym_space, Discrete) - model, _, _ = await actor_session.model_registry.retrieve_version( - DQNModel, config.model_id, config.model_iteration + model = await DQNModel.retrieve_model( + actor_session.model_registry, config.model_id, config.model_iteration ) model.network.eval() @@ -447,7 +448,7 @@ async def impl(self, run_session): run_id=run_session.run_id, seed=self._cfg.seed + trial_idx, model_id=model_id, - model_version=-1, + model_iteration=-1, model_update_frequency=self._cfg.model_update_frequency, environment_specs=self._environment_specs.serialize(), ), From 3f4c0274554a9d0e81a4dbe284e0a87b5c123232 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 11 May 2023 13:28:22 -0400 Subject: [PATCH 27/29] current --- actors/dqn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index b0c8143e..82e95e89 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -34,8 +34,6 @@ from cogment_verse import Model, TorchReplayBuffer # pylint: disable=abstract-class-instantiated -torch.manual_seed(0) -np.random.seed(0) torch.multiprocessing.set_sharing_strategy("file_system") log = logging.getLogger(__name__) @@ -52,7 +50,6 @@ def compute_value(t): # Acknowledgements: The networks and associated utils are adapted from RLHive - def calculate_output_dim(net, input_shape): if isinstance(input_shape, int): input_shape = (input_shape,) @@ -398,7 +395,7 @@ async def impl(self, run_session): ) run_session.log_params( - self._cfg, + # self._cfg, model_id=model_id, environment_implementation=self._environment_specs.implementation, ) From 8a5bb5cef8159e3859c24c059e902707db79db91 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Thu, 11 May 2023 13:28:44 -0400 Subject: [PATCH 28/29] black format --- actors/dqn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/actors/dqn.py b/actors/dqn.py index 82e95e89..e043991c 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -50,6 +50,7 @@ def compute_value(t): # Acknowledgements: The networks and associated utils are adapted from RLHive + def calculate_output_dim(net, input_shape): if isinstance(input_shape, int): input_shape = (input_shape,) @@ -249,9 +250,7 @@ async def impl(self, actor_session): assert isinstance(action_space.gym_space, Discrete) - model = await DQNModel.retrieve_model( - actor_session.model_registry, config.model_id, config.model_iteration - ) + model = await DQNModel.retrieve_model(actor_session.model_registry, config.model_id, config.model_iteration) model.network.eval() async for event in actor_session.all_events(): From 57b0fd4052500f2a8cead3893f4199e354804860 Mon Sep 17 00:00:00 2001 From: saikrishnagv_1996 Date: Mon, 12 Jun 2023 11:07:10 -0400 Subject: [PATCH 29/29] Update dqn.py --- actors/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/actors/dqn.py b/actors/dqn.py index e043991c..6fee1c4d 100644 --- a/actors/dqn.py +++ b/actors/dqn.py @@ -1,4 +1,4 @@ -# Copyright 2022 AI Redefined 
Inc. +# Copyright 2023 AI Redefined Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
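
The series converges on a bytes-based model registry API: DQNModel.serialize_model packs the network weights, epsilon, and model user data into a byte string that run_session.model_registry.publish_model / store_model persist, and actors rebuild the model via DQNModel.retrieve_model, which calls deserialize_model on the stored bytes. Below is a minimal round-trip sketch of that mechanism, assuming the cogment_verse Model base class accepts (model_id, iteration) as in the final patches; the model id and environment implementation strings are illustrative placeholders, not values from the patch series.

    from actors.dqn import DQNModel

    # Build a small model matching CartPole-like specs: a 4-dim flat observation
    # and 2 discrete actions, with the [128, 64] hidden layout used in the configs.
    model = DQNModel(
        model_id="demo_model",  # placeholder id
        environment_implementation="environments.gym.Environment",  # placeholder name
        num_input=4,
        num_output=2,
        num_hidden_nodes=[128, 64],
        epsilon=0.05,
    )

    # serialize_model returns bytes: torch.save of (state_dict, epsilon, user data)
    # into an io.BytesIO buffer. This is the payload handed to publish_model.
    blob = DQNModel.serialize_model(model)

    # deserialize_model rebuilds the model from the embedded user data
    # (num_input, num_output, JSON-encoded num_hidden_nodes) and reloads weights.
    restored = DQNModel.deserialize_model(blob)

    assert restored.epsilon == model.epsilon
    assert restored.get_model_user_data()["num_hidden_nodes"] == model.get_model_user_data()["num_hidden_nodes"]

In the training and actor implementations themselves, the same round trip appears as (per PATCH 15 and PATCH 26):

    # iteration_info = await run_session.model_registry.publish_model(name=model_id, model=serialized_model)
    # model = await DQNModel.retrieve_model(actor_session.model_registry, config.model_id, config.model_iteration)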