add action representation to bootstrapped DQN
Summary:
The current implementation of bootstrapped DQN raises a dimension mismatch error when combined with an action representation module, such as the OneHotActionTensorRepresentationModule used in CartPole-v1. The mismatch occurs because the deep exploration mechanism never applies the action representation module to the candidate actions before computing Q-values. This update fixes that by passing the module into DeepExploration and applying it in act().
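
For intuition, a minimal, library-agnostic sketch of the mismatch (not part of this commit; the tiny Q-head and the CartPole-v1 shapes are illustrative assumptions): a network sized for one-hot action inputs cannot consume raw action indices.

```python
import torch
import torch.nn.functional as F

# Illustrative only: CartPole-v1 has 2 discrete actions and a 4-dimensional state.
num_actions, state_dim = 2, 4
raw_actions = torch.arange(num_actions).unsqueeze(-1).float()                # (2, 1) raw indices
one_hot_actions = F.one_hot(torch.arange(num_actions), num_actions).float()  # (2, 2) one-hot

# A Q-head built for one-hot actions expects state_dim + num_actions input features.
q_head = torch.nn.Linear(state_dim + num_actions, 1)
states = torch.zeros(num_actions, state_dim)

q_head(torch.cat([states, one_hot_actions], dim=-1))   # OK: input has 6 features
# q_head(torch.cat([states, raw_actions], dim=-1))      # error: input has only 5 features
```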

Additionally, this update changes the default parameters in our benchmark settings for bootstrapped DQN: the prior scale is raised from 0 to 100, and the number of members in the Q ensemble network is reduced from 10 to 5. Without a prior network (prior scale of 0), all ensemble members rank actions almost identically after a few training steps, so there is no meaningful exploration and therefore no performance improvement. Reducing the ensemble size speeds up training.
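
For reference, a schematic sketch of the randomized-prior idea that prior_scale controls (Pearl's actual EnsembleQValueNetwork internals may differ; the class below is illustrative): each ensemble member adds a frozen, randomly initialized prior network, scaled by prior_scale, to its trainable network, which keeps the members' value estimates diverse.

```python
import torch

class PriorEnsembleMember(torch.nn.Module):
    """One ensemble member: trainable(x) + prior_scale * prior(x), with a frozen prior."""

    def __init__(self, in_dim: int, prior_scale: float) -> None:
        super().__init__()
        self.trainable = torch.nn.Linear(in_dim, 1)
        self.prior = torch.nn.Linear(in_dim, 1)
        for p in self.prior.parameters():
            p.requires_grad_(False)  # the prior is never updated during training
        self.prior_scale = prior_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With prior_scale = 0 the frozen prior vanishes, all members chase the same
        # targets, and sampling a member no longer yields diverse action rankings.
        return self.trainable(x) + self.prior_scale * self.prior(x)
```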

Reviewed By: rodrigodesalvobraz

Differential Revision: D60557188

fbshipit-source-id: 57ec2181703d973ec59af65751794b319ed3d8fd
yiwan-rl authored and facebook-github-bot committed Aug 14, 2024
1 parent f84334c commit 90d0b49
Showing 3 changed files with 19 additions and 4 deletions.

```diff
@@ -48,9 +48,11 @@ class DeepExploration(ExplorationModule):
     def __init__(
         self,
         q_ensemble_network: EnsembleQValueNetwork,
+        action_representation_module: torch.nn.Module,
     ) -> None:
         super(DeepExploration, self).__init__()
         self.q_ensemble_network = q_ensemble_network
+        self.action_representation_module = action_representation_module
 
     def act(
         self,
@@ -70,6 +72,8 @@ def act(
         actions = action_space.actions_batch.to(subjective_state.device)
         # (action_space_size, action_dim)
 
+        actions = self.action_representation_module(actions)
+
         with torch.no_grad():
             q_values = self.q_ensemble_network.get_q_values(
                 state_batch=states_repeated,
```
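
For context, a library-agnostic sketch of the deep exploration step this hunk fixes (function and argument names are illustrative assumptions, not Pearl's API): a single ensemble member is sampled per episode and actions are chosen greedily against it, with the action representation applied to the candidate actions before Q-values are computed.

```python
import torch

def deep_exploration_act(
    state: torch.Tensor,              # (state_dim,)
    candidate_actions: torch.Tensor,  # (action_space_size, 1) raw action indices
    ensemble: list,                   # per-member Q functions, e.g. torch.nn.Modules
    member_index: int,                # resampled once per episode, fixed within it
    action_representation,            # e.g. a one-hot encoder module
) -> int:
    actions = action_representation(candidate_actions)        # (A, action_dim)
    states = state.unsqueeze(0).expand(actions.shape[0], -1)  # (A, state_dim)
    with torch.no_grad():
        q_values = ensemble[member_index](torch.cat([states, actions], dim=-1))
    return int(q_values.argmax())  # greedy action w.r.t. the sampled member
```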

```diff
@@ -14,6 +14,9 @@
 from pearl.action_representation_modules.action_representation_module import (
     ActionRepresentationModule,
 )
+from pearl.action_representation_modules.identity_action_representation_module import (
+    IdentityActionRepresentationModule,
+)
 from pearl.api.action import Action
 
 from pearl.api.action_space import ActionSpace
@@ -62,11 +65,19 @@ def __init__(
         soft_update_tau: float = 1.0,
         action_representation_module: Optional[ActionRepresentationModule] = None,
     ) -> None:
+        assert isinstance(action_space, DiscreteActionSpace)
+        if action_representation_module is None:
+            action_representation_module = IdentityActionRepresentationModule(
+                max_number_actions=action_space.n,
+                representation_dim=action_space.action_dim,
+            )
         PolicyLearner.__init__(
             self=self,
             training_rounds=training_rounds,
             batch_size=batch_size,
-            exploration_module=DeepExploration(q_ensemble_network),
+            exploration_module=DeepExploration(
+                q_ensemble_network, action_representation_module
+            ),
             on_policy=False,
             is_action_continuous=False,
             action_representation_module=action_representation_module,
```
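
A minimal stand-in for the identity fallback added above (a hypothetical class, not Pearl's IdentityActionRepresentationModule): when the caller omits action_representation_module, an identity mapping is built from the action space, so downstream code always has a module to call.

```python
import torch

class IdentityRepresentation(torch.nn.Module):
    """Hypothetical stand-in mirroring the role of IdentityActionRepresentationModule."""

    def __init__(self, max_number_actions: int, representation_dim: int) -> None:
        super().__init__()
        self.max_number_actions = max_number_actions
        self.representation_dim = representation_dim

    def forward(self, actions: torch.Tensor) -> torch.Tensor:
        return actions  # identity: actions pass through unchanged

raw_actions = torch.tensor([[0.0], [1.0]])  # e.g. CartPole-v1's two actions, shape (2, 1)
module = IdentityRepresentation(max_number_actions=2, representation_dim=1)
assert torch.equal(module(raw_actions), raw_actions)
```
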
6 changes: 3 additions & 3 deletions pearl/utils/scripts/benchmark_config.py
```diff
@@ -292,14 +292,14 @@
     "replay_buffer_args": {
         "capacity": 50000,
         "p": 1.0,
-        "ensemble_size": 10,
+        "ensemble_size": 5,
     },
     "network_module": EnsembleQValueNetwork,
     "network_args": {
-        "ensemble_size": 10,
+        "ensemble_size": 5,
         "output_dim": 1,
         "hidden_dims": [64, 64],
-        "prior_scale": 0.0,
+        "prior_scale": 100.0,
     },
     "action_representation_module": OneHotActionTensorRepresentationModule,
     "action_representation_module_args": {},
```
