diff --git a/README.md b/README.md
index 6d33358..67b8cc0 100644
--- a/README.md
+++ b/README.md
@@ -226,7 +226,8 @@ A factory class for creating a memory buffer that has been implemented into the
 | REDQ | Vector | Continuous | [REDQ Paper](https://arxiv.org/pdf/2101.05982.pdf) |
 | TQC | Vector | Continuous | [TQC Paper](https://arxiv.org/abs/1812.05905) |
 | CTD4 | Vector | Continuous | [CTD4 Paper](https://arxiv.org/abs/2405.02576) |
 | CrossQ | Vector | Continuous | [CrossQ Paper](https://arxiv.org/pdf/1902.05605) |
+| DroQ | Vector | Continuous | [DroQ Paper](https://arxiv.org/abs/2110.02034) |
 | ----------- | -------------------------- | ------------ | --------------- |
 | NaSATD3 | Image | Continuous | In Submission |
 | TD3AE | Image | Continuous | [TD3AE Paper](https://arxiv.org/abs/1910.01741) |
diff --git a/cares_reinforcement_learning/algorithm/policy/DroQ.py b/cares_reinforcement_learning/algorithm/policy/DroQ.py
new file mode 100644
index 0000000..8dcacee
--- /dev/null
+++ b/cares_reinforcement_learning/algorithm/policy/DroQ.py
@@ -0,0 +1,23 @@
+"""
+Original Paper: https://openreview.net/pdf?id=xCVJMsPv3RT
+Code based on: https://github.com/TakuyaHiraoka/Dropout-Q-Functions-for-Doubly-Efficient-Reinforcement-Learning/blob/main/KUCodebase/code/agent.py
+
+This code runs automatic entropy tuning
+"""
+
+import torch
+
+from cares_reinforcement_learning.algorithm.policy import SAC
+from cares_reinforcement_learning.networks.DroQ import Actor, Critic
+from cares_reinforcement_learning.util.configurations import DroQConfig
+
+
+class DroQ(SAC):
+    def __init__(
+        self,
+        actor_network: Actor,
+        critic_network: Critic,
+        config: DroQConfig,
+        device: torch.device,
+    ):
+        super().__init__(actor_network, critic_network, config, device)
diff --git a/cares_reinforcement_learning/algorithm/policy/__init__.py b/cares_reinforcement_learning/algorithm/policy/__init__.py
index 86b3ba3..e50ed0e 100644
--- a/cares_reinforcement_learning/algorithm/policy/__init__.py
+++ b/cares_reinforcement_learning/algorithm/policy/__init__.py
@@ -20,4 +20,5 @@
 from .MAPERSAC import MAPERSAC
 from .LA3PSAC import LA3PSAC
 from .TQC import TQC
 from .CrossQ import CrossQ
+from .DroQ import DroQ
diff --git a/cares_reinforcement_learning/networks/DroQ/__init__.py b/cares_reinforcement_learning/networks/DroQ/__init__.py
new file mode 100644
index 0000000..f306b72
--- /dev/null
+++ b/cares_reinforcement_learning/networks/DroQ/__init__.py
@@ -0,0 +1,2 @@
+from .actor import DefaultActor, Actor
+from .critic import DefaultCritic, Critic
diff --git a/cares_reinforcement_learning/networks/DroQ/actor.py b/cares_reinforcement_learning/networks/DroQ/actor.py
new file mode 100644
index 0000000..28359b5
--- /dev/null
+++ b/cares_reinforcement_learning/networks/DroQ/actor.py
@@ -0,0 +1,6 @@
+"""
+This is a stub file for the Actor class - it reuses SAC's Actor class directly.
+"""
+
+# pylint: disable=unused-import
+from cares_reinforcement_learning.networks.SAC import Actor, DefaultActor
diff --git a/cares_reinforcement_learning/networks/DroQ/critic.py b/cares_reinforcement_learning/networks/DroQ/critic.py
new file mode 100644
index 0000000..800e6a4
--- /dev/null
+++ b/cares_reinforcement_learning/networks/DroQ/critic.py
@@ -0,0 +1,62 @@
+"""
+This file reuses SAC's Critic class and defines a DroQ DefaultCritic with dropout and layer normalisation.
+""" + +# pylint: disable=unused-import +from torch import nn + +from cares_reinforcement_learning.networks.common import TwinQNetwork +from cares_reinforcement_learning.networks.SAC import Critic +from cares_reinforcement_learning.util.configurations import DroQConfig, MLPConfig + + +class DefaultCritic(TwinQNetwork): + def __init__( + self, + observation_size: int, + num_actions: int, + ): + input_size = observation_size + num_actions + hidden_sizes = [256, 256] + + critic_config: MLPConfig = MLPConfig( + hidden_sizes=hidden_sizes, + dropout_layer="Dropout", + dropout_layer_args={"p": 0.005}, + norm_layer="LayerNorm", + layer_order=["dropout", "layernorm", "activation"], + ) + + super().__init__( + input_size=input_size, + output_size=1, + config=critic_config, + ) + + # Q1 architecture + # pylint: disable-next=invalid-name + self.Q1 = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.Dropout(0.005), + nn.LayerNorm(hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.Dropout(0.005), + nn.LayerNorm(hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], 1), + ) + + # Q2 architecture + # pylint: disable-next=invalid-name + self.Q2 = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.Dropout(0.005), + nn.LayerNorm(hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.Dropout(0.005), + nn.LayerNorm(hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], 1), + ) diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py index 7e61ed0..8a660eb 100644 --- a/cares_reinforcement_learning/util/configurations.py +++ b/cares_reinforcement_learning/util/configurations.py @@ -452,6 +452,36 @@ class CrossQConfig(AlgorithmConfig): ) +class DroQConfig(SACConfig): + algorithm: str = Field("DroQ", Literal=True) + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + alpha_lr: float = 3e-4 + + gamma: float = 0.99 + tau: float = 0.005 + reward_scale: float = 1.0 + + G: int = 20 + + log_std_bounds: list[float] = [-20, 2] + + policy_update_freq: int = 1 + target_update_freq: int = 1 + + hidden_size_actor: list[int] = [256, 256] + hidden_size_critic: list[int] = [256, 256] + + actor_config: MLPConfig = MLPConfig(hidden_sizes=[256, 256]) + critic_config: MLPConfig = MLPConfig( + hidden_sizes=[256, 256], + dropout_layer="Dropout", + dropout_layer_args={"p": 0.005}, + norm_layer="LayerNorm", + layer_order=["dropout", "layernorm", "activation"], + ) + + class DynaSACConfig(SACConfig): algorithm: str = Field("DynaSAC", Literal=True) actor_lr: float = 3e-4 diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py index 969a489..616ca97 100644 --- a/cares_reinforcement_learning/util/network_factory.py +++ b/cares_reinforcement_learning/util/network_factory.py @@ -269,6 +269,23 @@ def create_RDSAC(observation_size, action_num, config: acf.RDSACConfig): return agent +def create_DroQ(observation_size, action_num, config: acf.DroQConfig): + from cares_reinforcement_learning.algorithm.policy import DroQ + from cares_reinforcement_learning.networks.DroQ import Actor, Critic + + actor = Actor(observation_size, action_num, config=config) + critic = Critic(observation_size, action_num, config=config) + + device = hlp.get_device() + agent = DroQ( + actor_network=actor, + critic_network=critic, + config=config, + device=device, + ) + return agent + + def create_CrossQ(observation_size, action_num, config: 
 def create_CrossQ(observation_size, action_num, config: acf.CrossQConfig):
     from cares_reinforcement_learning.algorithm.policy import CrossQ
     from cares_reinforcement_learning.networks.CrossQ import Actor, Critic
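
Usage sketch (editorial note, not part of the patch): the snippet below mirrors the create_DroQ factory function added above to show how the new pieces fit together. The observation/action sizes, the use of torch.device("cpu"), and constructing DroQConfig with its defaults are illustrative assumptions, not tested code from this change.

# Hypothetical example -- mirrors create_DroQ from this diff; sizes and device are assumed.
import torch

from cares_reinforcement_learning.algorithm.policy import DroQ
from cares_reinforcement_learning.networks.DroQ import Actor, Critic
from cares_reinforcement_learning.util.configurations import DroQConfig

observation_size = 17  # example vector observation size (assumed)
action_num = 6         # example continuous action dimension (assumed)

# DroQConfig defaults: dropout p=0.005 and LayerNorm in the critics, G=20
config = DroQConfig()

actor = Actor(observation_size, action_num, config=config)
critic = Critic(observation_size, action_num, config=config)

agent = DroQ(
    actor_network=actor,
    critic_network=critic,
    config=config,
    device=torch.device("cpu"),
)

The G=20 default presumably corresponds to the high update-to-data ratio used in the DroQ paper, while the dropout and layer-normalisation settings in critic_config match the DefaultCritic defined in networks/DroQ/critic.py.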