CrossQ #204

Draft: wants to merge 32 commits into main from alg/cross-q

Commits (32)
bf635fd  CrossQ Implementation (beardyFace, Oct 11, 2024)
dfb33a5  Fixes to CrossQ implementation and default parameters (beardyFace, Oct 21, 2024)
304f430  Eval vs Train setting of actor/critic (beardyFace, Oct 23, 2024)
14f6b1c  Still not working (beardyFace, Oct 25, 2024)
aeabcfa  best crossq so far (beardyFace, Oct 28, 2024)
deadff2  MLP Generic norm and activation functions (beardyFace, Nov 20, 2024)
643425f  Shifting Configs into Actor/Critics to enable MLP parameters to be pa… (beardyFace, Nov 20, 2024)
16ff53e  Overkill MLP (beardyFace, Nov 21, 2024)
463b881  Updated all networks to use MLP (beardyFace, Nov 22, 2024)
3dcd047  Updated all networks but AE based for base + default + MLP networks (beardyFace, Nov 25, 2024)
388ec91  REDQ fix using Ensemble method from CTD4 (beardyFace, Nov 25, 2024)
d2679fc  SACAE and TD3AE updated to reflect default + base + actor/critic layo… (beardyFace, Nov 25, 2024)
56be414  Type hint for types in networks (beardyFace, Nov 25, 2024)
2588626  update for BatchNorm eval vs test (beardyFace, Nov 26, 2024)
b3c27e7  NaSATD3 Default + Base + Custom + algorithm tests (beardyFace, Dec 2, 2024)
5651e72  added asserts (beardyFace, Dec 2, 2024)
909fa8a  merged main (beardyFace, Dec 2, 2024)
91c3532  Updated with MLP update (beardyFace, Dec 2, 2024)
2d54abb  CrossQ Base, Default, Custom - not MLP yet as different layout of bat… (beardyFace, Dec 2, 2024)
4bee841  Type hint for CrossQ (beardyFace, Dec 2, 2024)
aee0a08  merged main (beardyFace, Dec 2, 2024)
95763d3  update main (beardyFace, Dec 3, 2024)
64ed5ba  BatchReNorm with warm up moved into code base (beardyFace, Dec 3, 2024)
b4a3c3d  batchrenorm (beardyFace, Dec 4, 2024)
38a4219  BatchReNorm from pytorch rl (beardyFace, Dec 4, 2024)
5944c1e  batchrenorm from stablebaselines best so far - but not perfect (beardyFace, Dec 4, 2024)
07aad0a  Merge branch 'main' into alg/cross-q (beardyFace, Dec 11, 2024)
297b6cf  Merged with main - configurable CrossQ - shifted BatchReNorm to netwo… (beardyFace, Dec 11, 2024)
1ae79c0  beta for optimizers for actor and critic (beardyFace, Dec 11, 2024)
3542ab2  merge with main (beardyFace, Dec 12, 2024)
991cd46  README (beardyFace, Dec 12, 2024)
ca96b21  README (beardyFace, Dec 12, 2024)
1 change: 1 addition & 0 deletions README.md
@@ -226,6 +226,7 @@ A factory class for creating a memory buffer that has been implemented into the
| REDQ | Vector | Continuous | [REDQ Paper](https://arxiv.org/pdf/2101.05982.pdf) |
| TQC | Vector | Continuous | [TQC Paper](https://arxiv.org/abs/1812.05905) |
| CTD4 | Vector | Continuous | [CTD4 Paper](https://arxiv.org/abs/2405.02576) |
| CrossQ | Vector | Continuous | [CrossQ Paper](https://arxiv.org/pdf/1902.05605) |
| Droq | Vector | Continuous | [DroQ Paper](https://arxiv.org/abs/2110.02034) |
| ----------- | -------------------------- | ------------ | --------------- |
| NaSATD3 | Image | Continuous | In Submission |
198 changes: 198 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/CrossQ.py
@@ -0,0 +1,198 @@
"""
Original Paper: https://arxiv.org/pdf/1902.05605
Code based on: https://github.com/modelbased/minirllab/blob/main/agents/sac_crossq.py

This implementation uses automatic entropy (temperature) tuning.
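CrossQ removes the target networks used in SAC-style algorithms and instead stabilises
the TD targets by passing current and next state-action pairs through critics that use
Batch Renormalization on the joint batch.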
"""

import logging
import os
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.networks.CrossQ import Actor, Critic
from cares_reinforcement_learning.util.configurations import CrossQConfig


class CrossQ:
def __init__(
self,
actor_network: Actor,
critic_network: Critic,
config: CrossQConfig,
device: torch.device,
):
self.type = "policy"
self.device = device

# this may be called policy_net in other implementations
self.actor_net = actor_network.to(device)

# this may be called soft_q_net in other implementations
self.critic_net = critic_network.to(device)

self.gamma = config.gamma
self.reward_scale = config.reward_scale

self.learn_counter = 0
self.policy_update_freq = config.policy_update_freq

self.target_entropy = -np.prod(self.actor_net.num_actions)
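        # Standard SAC heuristic: target entropy is -|A|, the negative of the action dimension.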

self.actor_net_optimiser = torch.optim.Adam(
self.actor_net.parameters(), lr=config.actor_lr, betas=(0.5, 0.999)
)
self.critic_net_optimiser = torch.optim.Adam(
self.critic_net.parameters(), lr=config.critic_lr, betas=(0.5, 0.999)
)

# Temperature (alpha) for the entropy loss
        # Initial alpha is set to 1.0, in line with other baselines.
init_temperature = 1.0
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
self.log_alpha.requires_grad = True
self.log_alpha_optimizer = torch.optim.Adam(
[self.log_alpha], lr=config.alpha_lr
)

# pylint: disable-next=unused-argument
def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
) -> np.ndarray:
        # Note: during evaluation the deterministic mean action (mu) is selected instead of a sample.
self.actor_net.eval()
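        # eval() matters here: BatchRenorm must use its running statistics when acting on a single state.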
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
if evaluation:
(_, _, action) = self.actor_net(state_tensor)
else:
(action, _, _) = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
self.actor_net.train()
return action

@property
def alpha(self) -> torch.Tensor:
return self.log_alpha.exp()

def _update_critic(
self,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
) -> tuple[float, float, float]:

with torch.no_grad():
with hlp.evaluating(self.actor_net):
next_actions, next_log_pi, _ = self.actor_net(next_states)

cat_states = torch.cat([states, next_states], dim=0)
cat_actions = torch.cat([actions, next_actions], dim=0)
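        # CrossQ's key step: current and next (state, action) pairs go through the
        # critic as one concatenated batch, so the Batch(Re)Norm layers normalise
        # over both distributions jointly; this is what replaces target networks.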

cat_q_values_one, cat_q_values_two = self.critic_net(cat_states, cat_actions)

q_values_one, q_values_one_next = torch.chunk(cat_q_values_one, chunks=2, dim=0)
q_values_two, q_values_two_next = torch.chunk(cat_q_values_two, chunks=2, dim=0)

target_q_values = (
torch.minimum(q_values_one_next, q_values_two_next)
- self.alpha * next_log_pi
)

q_target = (
rewards * self.reward_scale + self.gamma * (1 - dones) * target_q_values
)
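        # Stop gradients from flowing through the bootstrapped target values.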
        q_target = q_target.detach()

critic_loss_one = F.mse_loss(q_values_one, q_target)
critic_loss_two = F.mse_loss(q_values_two, q_target)

critic_loss_total = critic_loss_one + critic_loss_two

self.critic_net_optimiser.zero_grad()
critic_loss_total.backward()
self.critic_net_optimiser.step()

return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
pi, log_pi, _ = self.actor_net(states)

with hlp.evaluating(self.critic_net):
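            # The critic is queried in eval mode (via hlp.evaluating) so that its
            # normalisation running statistics are not updated by these policy-improvement batches.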
qf1_pi, qf2_pi = self.critic_net(states, pi)

min_qf_pi = torch.minimum(qf1_pi, qf2_pi)

actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

self.actor_net_optimiser.zero_grad()
actor_loss.backward()
self.actor_net_optimiser.step()

# update the temperature (alpha)
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

self.log_alpha_optimizer.zero_grad()
alpha_loss.backward()
self.log_alpha_optimizer.step()

return actor_loss.item(), alpha_loss.item()

def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
self.learn_counter += 1

experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences

batch_size = len(states)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
dones = torch.LongTensor(np.asarray(dones)).to(self.device)

        # Reshape to (batch_size, 1)
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

info = {}

# Update the Critic
critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
states, actions, rewards, next_states, dones
)
info["critic_loss_one"] = critic_loss_one
info["critic_loss_two"] = critic_loss_two
info["critic_loss"] = critic_loss_total

if self.learn_counter % self.policy_update_freq == 0:
# Update the Actor and Alpha
actor_loss, alpha_loss = self._update_actor_alpha(states)
info["actor_loss"] = actor_loss
info["alpha_loss"] = alpha_loss
info["alpha"] = self.alpha.item()

return info

def save_models(self, filepath: str, filename: str) -> None:
if not os.path.exists(filepath):
os.makedirs(filepath)

torch.save(self.actor_net.state_dict(), f"{filepath}/{filename}_actor.pht")
torch.save(self.critic_net.state_dict(), f"{filepath}/{filename}_critic.pht")
        logging.info("models have been saved...")

def load_models(self, filepath: str, filename: str) -> None:
self.actor_net.load_state_dict(torch.load(f"{filepath}/{filename}_actor.pht"))
self.critic_net.load_state_dict(torch.load(f"{filepath}/{filename}_critic.pht"))
        logging.info("models have been loaded...")
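
For context, a minimal, hypothetical usage sketch (not part of this diff) showing how the new agent could be wired together by hand. It assumes CrossQConfig() and MemoryBuffer() can be constructed with their defaults; the observation and action sizes are placeholders.

import torch

from cares_reinforcement_learning.algorithm.policy import CrossQ
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.networks.CrossQ import Actor, Critic
from cares_reinforcement_learning.util.configurations import CrossQConfig

config = CrossQConfig()  # assumption: default construction is valid
observation_size, num_actions = 17, 6  # placeholder environment dimensions

actor = Actor(observation_size, num_actions, config)
critic = Critic(observation_size, num_actions, config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = CrossQ(actor, critic, config=config, device=device)

memory = MemoryBuffer()  # assumption: no required constructor arguments

# ...collect (state, action, reward, next_state, done) transitions into memory, then:
info = agent.train_policy(memory, batch_size=256)
print(info["critic_loss"], info.get("actor_loss"))
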
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/algorithm/policy/PERSAC.py
@@ -189,7 +189,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
info["critic_loss"] = critic_loss_total

if self.learn_counter % self.policy_update_freq == 0:
# Update the Actor
# Update the Actor and Alpha
actor_loss, alpha_loss = self._update_actor_alpha(states_tensor)
info["actor_loss"] = actor_loss
info["alpha_loss"] = alpha_loss
1 change: 1 addition & 0 deletions cares_reinforcement_learning/algorithm/policy/__init__.py
@@ -20,4 +20,5 @@
from .MAPERSAC import MAPERSAC
from .LA3PSAC import LA3PSAC
from .TQC import TQC
from .CrossQ import CrossQ
from .DroQ import DroQ
2 changes: 2 additions & 0 deletions cares_reinforcement_learning/networks/CrossQ/__init__.py
@@ -0,0 +1,2 @@
from .actor import Actor, DefaultActor
from .critic import Critic
58 changes: 58 additions & 0 deletions cares_reinforcement_learning/networks/CrossQ/actor.py
@@ -0,0 +1,58 @@
from torch import nn

from cares_reinforcement_learning.networks.common import TanhGaussianPolicy
from cares_reinforcement_learning.networks.batchrenorm import BatchRenorm1d
from cares_reinforcement_learning.util.configurations import CrossQConfig, MLPConfig


class DefaultActor(TanhGaussianPolicy):
    # DiagGaussianActor
    """torch.distributions implementation of a diagonal Gaussian policy."""

def __init__(
self,
observation_size: int,
num_actions: int,
hidden_sizes: list[int] | None = None,
log_std_bounds: list[float] | None = None,
):
if hidden_sizes is None:
hidden_sizes = [256, 256]

if log_std_bounds is None:
log_std_bounds = [-20.0, 2.0]

momentum = 0.01
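        # momentum=0.01 keeps the Batch Renormalization running statistics slow-moving,
        # matching reference CrossQ implementations.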
super().__init__(
input_size=observation_size,
num_actions=num_actions,
log_std_bounds=log_std_bounds,
config=MLPConfig(hidden_sizes=hidden_sizes),
)

self.act_net = nn.Sequential(
BatchRenorm1d(observation_size, momentum=momentum),
nn.Linear(observation_size, hidden_sizes[0], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[0], momentum=momentum),
nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[1], momentum=momentum),
)

self.mean_linear = nn.Linear(hidden_sizes[-1], num_actions)
self.log_std_linear = nn.Linear(hidden_sizes[-1], num_actions)


class Actor(TanhGaussianPolicy):
    # DiagGaussianActor
    """torch.distributions implementation of a diagonal Gaussian policy."""

def __init__(self, observation_size: int, num_actions: int, config: CrossQConfig):

super().__init__(
input_size=observation_size,
num_actions=num_actions,
log_std_bounds=config.log_std_bounds,
config=config.actor_config,
)
60 changes: 60 additions & 0 deletions cares_reinforcement_learning/networks/CrossQ/critic.py
@@ -0,0 +1,60 @@
from torch import nn

from cares_reinforcement_learning.networks.batchrenorm import BatchRenorm1d
from cares_reinforcement_learning.networks.common import TwinQNetwork
from cares_reinforcement_learning.util.configurations import CrossQConfig, MLPConfig


class DefaultCritic(TwinQNetwork):
def __init__(
self,
observation_size: int,
num_actions: int,
hidden_sizes: list[int] | None = None,
):
if hidden_sizes is None:
hidden_sizes = [2048, 2048]
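            # The CrossQ paper favours much wider critic layers (2048 units) than the usual 256-unit SAC default.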

input_size = observation_size + num_actions

super().__init__(
input_size=input_size,
output_size=1,
config=MLPConfig(hidden_sizes=hidden_sizes),
)

        momentum = 0.01

        # Q1 architecture
        # pylint: disable-next=invalid-name
        self.Q1 = nn.Sequential(
BatchRenorm1d(input_size, momentum=momentum),
nn.Linear(input_size, hidden_sizes[0], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[0], momentum=momentum),
nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[1], momentum=momentum),
nn.Linear(hidden_sizes[1], 1),
)

# Q2 architecture
# pylint: disable-next=invalid-name
self.Q2 = nn.Sequential(
BatchRenorm1d(input_size, momentum=momentum),
nn.Linear(input_size, hidden_sizes[0], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[0], momentum=momentum),
nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
nn.ReLU(),
BatchRenorm1d(hidden_sizes[1], momentum=momentum),
nn.Linear(hidden_sizes[1], 1),
)


class Critic(TwinQNetwork):
def __init__(self, observation_size: int, num_actions: int, config: CrossQConfig):
input_size = observation_size + num_actions

super().__init__(
input_size=input_size, output_size=1, config=config.critic_config
)