
Commit

merge with main
beardyFace committed Dec 12, 2024
2 parents 1ae79c0 + de44df3 commit 3542ab2
Showing 8 changed files with 148 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -226,7 +226,11 @@ A factory class for creating a memory buffer that has been implemented into the
| REDQ | Vector | Continuous | [REDQ Paper](https://arxiv.org/pdf/2101.05982.pdf) |
| TQC | Vector | Continuous | [TQC Paper](https://arxiv.org/abs/2005.04269) |
| CTD4 | Vector | Continuous | [CTD4 Paper](https://arxiv.org/abs/2405.02576) |
| CrossQ | Vector | Continuous | [CrossQ Paper](https://arxiv.org/pdf/1902.05605) |
| DroQ | Vector | Continuous | [DroQ Paper](https://arxiv.org/abs/2110.02034) |
| ----------- | -------------------------- | ------------ | --------------- |
| NaSATD3 | Image | Continuous | In Submission |
| TD3AE | Image | Continuous | [TD3AE Paper](https://arxiv.org/abs/1910.01741) |
23 changes: 23 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/DroQ.py
@@ -0,0 +1,23 @@
"""
Original Paper: https://openreview.net/pdf?id=xCVJMsPv3RT
Code based on: https://github.com/TakuyaHiraoka/Dropout-Q-Functions-for-Doubly-Efficient-Reinforcement-Learning/blob/main/KUCodebase/code/agent.py
This implementation uses automatic entropy tuning.
"""

import torch

from cares_reinforcement_learning.algorithm.policy import SAC
from cares_reinforcement_learning.networks.DroQ import Actor, Critic
from cares_reinforcement_learning.util.configurations import DroQConfig


class DroQ(SAC):
    def __init__(
        self,
        actor_network: Actor,
        critic_network: Critic,
        config: DroQConfig,
        device: torch.device,
    ):
        super().__init__(actor_network, critic_network, config, device)
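Since DroQ only overrides the constructor, all of its training behaviour is inherited from SAC; the DroQ-specific behaviour comes entirely from the networks and config passed in. A minimal sketch of instantiating the agent directly, assuming the SAC DefaultActor takes the same (observation_size, num_actions) signature as the DefaultCritic shown later in this diff, and using hypothetical environment sizes:

import torch

from cares_reinforcement_learning.algorithm.policy import DroQ
from cares_reinforcement_learning.networks.DroQ import DefaultActor, DefaultCritic
from cares_reinforcement_learning.util.configurations import DroQConfig

observation_size, num_actions = 17, 6  # hypothetical environment dimensions

config = DroQConfig()
actor = DefaultActor(observation_size, num_actions)    # assumed signature, mirrors DefaultCritic
critic = DefaultCritic(observation_size, num_actions)

agent = DroQ(actor, critic, config, device=torch.device("cpu"))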
4 changes: 4 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/__init__.py
@@ -20,4 +20,8 @@
from .MAPERSAC import MAPERSAC
from .LA3PSAC import LA3PSAC
from .TQC import TQC
from .CrossQ import CrossQ
from .DroQ import DroQ
2 changes: 2 additions & 0 deletions cares_reinforcement_learning/networks/DroQ/__init__.py
@@ -0,0 +1,2 @@
from .actor import DefaultActor, Actor
from .critic import DefaultCritic, Critic
6 changes: 6 additions & 0 deletions cares_reinforcement_learning/networks/DroQ/actor.py
@@ -0,0 +1,6 @@
"""
This is a stub file for the Actor class - it re-uses SAC's Actor class directly.
"""

# pylint: disable=unused-import
from cares_reinforcement_learning.networks.SAC import Actor, DefaultActor
62 changes: 62 additions & 0 deletions cares_reinforcement_learning/networks/DroQ/critic.py
@@ -0,0 +1,62 @@
"""
This file re-uses SAC's Critic class directly and defines a DefaultCritic configured with DroQ's dropout and layer normalisation.
"""

# pylint: disable=unused-import
from torch import nn

from cares_reinforcement_learning.networks.common import TwinQNetwork
from cares_reinforcement_learning.networks.SAC import Critic
from cares_reinforcement_learning.util.configurations import DroQConfig, MLPConfig


class DefaultCritic(TwinQNetwork):
    def __init__(
        self,
        observation_size: int,
        num_actions: int,
    ):
        input_size = observation_size + num_actions
        hidden_sizes = [256, 256]

        critic_config: MLPConfig = MLPConfig(
            hidden_sizes=hidden_sizes,
            dropout_layer="Dropout",
            dropout_layer_args={"p": 0.005},
            norm_layer="LayerNorm",
            layer_order=["dropout", "layernorm", "activation"],
        )

        super().__init__(
            input_size=input_size,
            output_size=1,
            config=critic_config,
        )

        # Q1 architecture
        # pylint: disable-next=invalid-name
        self.Q1 = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.Dropout(0.005),
            nn.LayerNorm(hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.Dropout(0.005),
            nn.LayerNorm(hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[1], 1),
        )

        # Q2 architecture
        # pylint: disable-next=invalid-name
        self.Q2 = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.Dropout(0.005),
            nn.LayerNorm(hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.Dropout(0.005),
            nn.LayerNorm(hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[1], 1),
        )
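The Dropout → LayerNorm → activation ordering after each linear layer is what distinguishes this critic from the plain SAC critic. A short sketch exercising the two Q-heads defined above (sizes are hypothetical; only the Q1/Q2 modules from this diff are used):

import torch

from cares_reinforcement_learning.networks.DroQ import DefaultCritic

critic = DefaultCritic(observation_size=17, num_actions=6)  # hypothetical sizes

state = torch.randn(32, 17)   # batch of states
action = torch.randn(32, 6)   # batch of actions
state_action = torch.cat([state, action], dim=1)

# Dropout is active in train() mode (nn.Dropout default behaviour)
# and disabled in eval() mode.
q1 = critic.Q1(state_action)  # shape (32, 1)
q2 = critic.Q2(state_action)  # shape (32, 1)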
30 changes: 30 additions & 0 deletions cares_reinforcement_learning/util/configurations.py
@@ -452,6 +452,36 @@ class CrossQConfig(AlgorithmConfig):
    )


class DroQConfig(SACConfig):
    algorithm: str = Field("DroQ", Literal=True)
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4
    alpha_lr: float = 3e-4

    gamma: float = 0.99
    tau: float = 0.005
    reward_scale: float = 1.0

    G: int = 20

    log_std_bounds: list[float] = [-20, 2]

    policy_update_freq: int = 1
    target_update_freq: int = 1

    hidden_size_actor: list[int] = [256, 256]
    hidden_size_critic: list[int] = [256, 256]

    actor_config: MLPConfig = MLPConfig(hidden_sizes=[256, 256])
    critic_config: MLPConfig = MLPConfig(
        hidden_sizes=[256, 256],
        dropout_layer="Dropout",
        dropout_layer_args={"p": 0.005},
        norm_layer="LayerNorm",
        layer_order=["dropout", "layernorm", "activation"],
    )
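A hedged usage sketch for the config above, assuming DroQConfig is a pydantic model (as the Field(...) default suggests) so fields can be overridden by keyword at construction time:

from cares_reinforcement_learning.util.configurations import DroQConfig

# Defaults follow the additions above: high update-to-data ratio (G=20)
# and a critic MLP with Dropout(p=0.005) + LayerNorm per hidden layer.
config = DroQConfig()
print(config.algorithm, config.G, config.critic_config.dropout_layer_args)

# Hypothetical override for a cheaper run: fewer gradient steps per environment step.
fast_config = DroQConfig(G=5)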


class DynaSACConfig(SACConfig):
    algorithm: str = Field("DynaSAC", Literal=True)
    actor_lr: float = 3e-4
17 changes: 17 additions & 0 deletions cares_reinforcement_learning/util/network_factory.py
@@ -269,6 +269,23 @@ def create_RDSAC(observation_size, action_num, config: acf.RDSACConfig):
    return agent


def create_DroQ(observation_size, action_num, config: acf.DroQConfig):
    from cares_reinforcement_learning.algorithm.policy import DroQ
    from cares_reinforcement_learning.networks.DroQ import Actor, Critic

    actor = Actor(observation_size, action_num, config=config)
    critic = Critic(observation_size, action_num, config=config)

    device = hlp.get_device()
    agent = DroQ(
        actor_network=actor,
        critic_network=critic,
        config=config,
        device=device,
    )
    return agent
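Like the other factory helpers in this file, create_DroQ can be called directly. A sketch importing the configurations module explicitly (rather than through the acf alias used in this file) and using hypothetical environment sizes:

from cares_reinforcement_learning.util.configurations import DroQConfig
from cares_reinforcement_learning.util.network_factory import create_DroQ

config = DroQConfig()
agent = create_DroQ(observation_size=17, action_num=6, config=config)  # hypothetical sizes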


def create_CrossQ(observation_size, action_num, config: acf.CrossQConfig):
    from cares_reinforcement_learning.algorithm.policy import CrossQ
    from cares_reinforcement_learning.networks.CrossQ import Actor, Critic
