BaseCritic + BaseActor
beardyFace committed Dec 11, 2024
1 parent e92afa7 commit aa67a7a
Showing 6 changed files with 115 additions and 104 deletions.
14 changes: 7 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/CTD4.py
@@ -17,15 +17,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
-from cares_reinforcement_learning.networks.CTD4 import Actor, EnsembleCritic
+from cares_reinforcement_learning.networks.CTD4 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import CTD4Config
 
 
 class CTD4:
     def __init__(
         self,
         actor_network: Actor,
-        ensemble_critic: EnsembleCritic,
+        ensemble_critic: Critic,
         config: CTD4Config,
         device: torch.device,
     ):
@@ -67,7 +67,7 @@ def __init__(
         self.lr_ensemble_critic = config.critic_lr
         self.ensemble_critic_optimizers = [
             torch.optim.Adam(critic_net.parameters(), lr=self.lr_ensemble_critic)
-            for critic_net in self.ensemble_critic
+            for critic_net in self.ensemble_critic.critics
         ]
 
     def select_action_from_policy(
@@ -172,7 +172,7 @@ def _update_critics(
         u_set = []
         std_set = []
 
-        for target_critic_net in self.target_ensemble_critic:
+        for target_critic_net in self.target_ensemble_critic.critics:
             u, std = target_critic_net(next_states, next_actions)
 
             u_set.append(u)
@@ -199,7 +199,7 @@ def _update_critics(
         critic_loss_totals = []
 
         for critic_net, critic_net_optimiser in zip(
-            self.ensemble_critic, self.ensemble_critic_optimizers
+            self.ensemble_critic.critics, self.ensemble_critic_optimizers
        ):
             u_current, std_current = critic_net(states, actions)
             current_distribution = torch.distributions.normal.Normal(
@@ -227,7 +227,7 @@ def _update_actor(self, states: torch.Tensor) -> float:
 
         actions = self.actor_net(states)
         with hlp.evaluating(self.ensemble_critic):
-            for critic_net in self.ensemble_critic:
+            for critic_net in self.ensemble_critic.critics:
                 actor_q_u, actor_q_std = critic_net(states, actions)
 
                 actor_q_u_set.append(actor_q_u)
@@ -297,7 +297,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         # Update ensemble of target critics
         for critic_net, target_critic_net in zip(
-            self.ensemble_critic, self.target_ensemble_critic
+            self.ensemble_critic.critics, self.target_ensemble_critic.critics
         ):
             hlp.soft_update_params(critic_net, target_critic_net, self.tau)
 
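Note: the same pattern recurs throughout CTD4.py: the algorithm no longer iterates the ensemble object directly but goes through its .critics attribute. A minimal sketch of that access pattern, assuming only that the ensemble exposes its members as a list under .critics (the attribute name comes from the diff; the toy network and hyperparameters are illustrative):

import torch
from torch import nn

# Hypothetical stand-in for the refactored ensemble critic: members are
# reachable only through the `critics` attribute, as in the diff above.
class ToyEnsemble(nn.Module):
    def __init__(self, ensemble_size: int = 3):
        super().__init__()
        self.critics = [nn.Linear(4, 1) for _ in range(ensemble_size)]
        for i, critic in enumerate(self.critics):
            self.add_module(f"critic_net_{i}", critic)

ensemble = ToyEnsemble()
target_ensemble = ToyEnsemble()

# One optimiser per ensemble member, as in __init__ above.
optimizers = [
    torch.optim.Adam(critic.parameters(), lr=1e-3) for critic in ensemble.critics
]

# Soft-update each target member, as in train_policy above (tau is illustrative).
tau = 0.005
for critic, target in zip(ensemble.critics, target_ensemble.critics):
    for param, target_param in zip(critic.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)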
1 change: 0 additions & 1 deletion cares_reinforcement_learning/networks/CTD4/__init__.py
@@ -1,3 +1,2 @@
 from .actor import Actor, DefaultActor
 from .critic import Critic, DefaultCritic
-from .ensemble_critic import EnsembleCritic, BaseEnsembleCritic
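Note: with ensemble_critic.py removed, the package's public surface is just the actor and critic modules. A short usage sketch of the new imports (the dimensions here are illustrative; the three-member default follows from critic.py below):

from cares_reinforcement_learning.networks.CTD4 import Critic, DefaultCritic

# DefaultCritic needs only the problem dimensions; Critic additionally
# takes a CTD4Config (see critic.py below).
critic = DefaultCritic(observation_size=8, num_actions=2)
print(len(critic.critics))  # 3 ensemble members by default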
44 changes: 38 additions & 6 deletions cares_reinforcement_learning/networks/CTD4/critic.py
@@ -1,17 +1,21 @@
 from torch import nn
 
-from cares_reinforcement_learning.util.common import ContinuousDistributedCritic
+from cares_reinforcement_learning.util.common import (
+    ContinuousDistributedCritic,
+    EnsembleCritic,
+)
 from cares_reinforcement_learning.util.configurations import CTD4Config, MLPConfig
 
 
 # This is the default base network for CTD4 for reference and testing of default network configurations
-class DefaultCritic(ContinuousDistributedCritic):
+class DefaultContinuousDistributedCritic(ContinuousDistributedCritic):
     def __init__(self, observation_size: int, action_num: int):
         input_size = observation_size + action_num
         hidden_sizes = [256, 256]
 
         super().__init__(
             input_size=input_size,
+            output_size=1,
             config=MLPConfig(hidden_sizes=hidden_sizes),
         )
@@ -34,8 +38,36 @@ def __init__(self, observation_size: int, action_num: int):
         self.soft_std_layer = nn.Softplus()
 
 
-class Critic(ContinuousDistributedCritic):
-    def __init__(self, observation_size: int, action_num: int, config: CTD4Config):
-        input_size = observation_size + action_num
+class DefaultCritic(EnsembleCritic):
+    def __init__(self, observation_size: int, num_actions: int):
+        input_size = observation_size + num_actions
+
+        ensemble_size = 3
+        hidden_sizes = [256, 256]
+
+        super().__init__(
+            input_size=input_size,
+            output_size=1,
+            ensemble_size=ensemble_size,
+            config=MLPConfig(hidden_sizes=hidden_sizes),
+        )
+
+        for i in range(ensemble_size):
+            critic_net = DefaultContinuousDistributedCritic(
+                observation_size=observation_size, action_num=num_actions
+            )
+            self.add_module(f"critic_net_{i}", critic_net)
+            self.critics[i] = critic_net
+
+
+class Critic(EnsembleCritic):
+    def __init__(self, observation_size: int, num_actions: int, config: CTD4Config):
+        input_size = observation_size + num_actions
 
-        super().__init__(input_size=input_size, config=config.critic_config)
+        super().__init__(
+            input_size=input_size,
+            output_size=1,
+            ensemble_size=config.ensemble_size,
+            config=config.critic_config,
+            critic_type=ContinuousDistributedCritic,
+        )
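Note: both classes now inherit from EnsembleCritic in util.common, whose definition is not part of this diff. A speculative sketch of what such a base could look like, inferred solely from the constructor calls above (input_size, output_size, ensemble_size, config, optional critic_type) and from the self.critics list that DefaultCritic overwrites; the real implementation may well differ:

from torch import nn

class EnsembleCriticSketch(nn.Module):
    # Hypothetical reconstruction of util.common.EnsembleCritic, not the real class.
    def __init__(self, input_size, output_size, ensemble_size, config, critic_type=None):
        super().__init__()
        self.critics: list[nn.Module] = []
        for i in range(ensemble_size):
            if critic_type is not None:
                # Build each member from the supplied critic class, matching the
                # keyword arguments Critic passes for ContinuousDistributedCritic.
                critic = critic_type(
                    input_size=input_size, output_size=output_size, config=config
                )
            else:
                # Placeholder member; subclasses such as DefaultCritic above
                # replace these entries and re-register them via add_module.
                critic = nn.Linear(input_size, output_size)
            self.add_module(f"critic_net_{i}", critic)
            self.critics.append(critic)

Keeping the members in a plain list while registering each through add_module would match how DefaultCritic both assigns self.critics[i] and calls add_module, and it would explain why CTD4.py iterates .critics rather than the ensemble module itself.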
46 changes: 0 additions & 46 deletions cares_reinforcement_learning/networks/CTD4/ensemble_critic.py

This file was deleted.

