Handling target vs policy update parameters
beardyFace committed Nov 4, 2024
1 parent 2ed4652 commit 13f85cf
Showing 15 changed files with 120 additions and 69 deletions.
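
The change is the same across the affected algorithms: the hard-coded update frequencies are replaced with values read from the algorithm config, and a single gate is split into two. `policy_update_freq` now controls how often the actor (and, where present, the entropy temperature) is updated, while `target_update_freq` controls how often the target critic is soft-updated. A minimal sketch of the config surface this assumes (field names taken from the diff; the class shape and default values are illustrative, not part of this commit):

from dataclasses import dataclass


@dataclass
class AlgorithmConfigSketch:
    # Illustrative stand-in for the library's config objects.
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4
    tau: float = 0.005              # soft-update coefficient passed to soft_update_params
    policy_update_freq: int = 1     # actor/alpha updated every N learn steps
    target_update_freq: int = 1     # target critic soft-updated every N learn steps
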
10 changes: 6 additions & 4 deletions cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
@@ -48,7 +48,8 @@ def __init__(
         self.action_num = self.actor_net.num_actions
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
@@ -93,10 +94,11 @@ def _update_critic_actor(self, states, actions, rewards, next_states, dones):
         # Update Critic
         self._update_critic(states, actions, rewards, next_states, dones)
 
-        # Update Actor
-        self._update_actor(states)
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update Actor
+            self._update_actor(states)
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
     def _update_critic(self, states, actions, rewards, next_states, dones):
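
The DynaSAC hunk above shows the gating pattern that repeats through the rest of the commit: the critic is updated on every learn step, the actor update is skipped unless the learn counter is a multiple of policy_update_freq, and the target critic is soft-updated only on multiples of target_update_freq. A self-contained sketch of that pattern, assuming the counter is incremented once per training call and that soft_update_params is the usual Polyak average (the helper below is illustrative, not the library's hlp module):

import torch


def soft_update_params(net: torch.nn.Module, target_net: torch.nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


def train_step_sketch(agent, states, actions, rewards, next_states, dones) -> None:
    # 'agent' stands in for any of the algorithm classes touched by this commit.
    agent.learn_counter += 1

    # Critic is updated on every learn step.
    agent._update_critic(states, actions, rewards, next_states, dones)

    # Actor (and temperature, where applicable) updates are delayed.
    if agent.learn_counter % agent.policy_update_freq == 0:
        agent._update_actor(states)

    # Target critic follows its own schedule.
    if agent.learn_counter % agent.target_update_freq == 0:
        soft_update_params(agent.critic_net, agent.target_critic_net, agent.tau)
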
4 changes: 2 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/LA3PSAC.py
@@ -42,7 +42,7 @@ def __init__(
         self.prioritized_fraction = config.prioritized_fraction
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -190,7 +190,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         uniform_batch_size = int(batch_size * (1 - self.prioritized_fraction))
         priority_batch_size = int(batch_size * self.prioritized_fraction)
 
-        policy_update = self.learn_counter % self.policy_update_freq == 0
+        policy_update = self.learn_counter % self.target_update_freq == 0
 
         ######################### UNIFORM SAMPLING #########################
         experiences = memory.sample_uniform(uniform_batch_size)
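
LA3PSAC keeps a single schedule: the policy_update flag that gates the actor update on both the uniform and the prioritised batch is now driven by target_update_freq instead of the removed policy_update_freq. For reference, the batch split just above works out as follows (the numbers are illustrative, not defaults from this commit):

batch_size = 256
prioritized_fraction = 0.25  # illustrative value

uniform_batch_size = int(batch_size * (1 - prioritized_fraction))  # 192 uniform samples
priority_batch_size = int(batch_size * prioritized_fraction)       # 64 prioritised samples
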
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/LAPSAC.py
@@ -40,7 +40,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -178,13 +179,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         info["huber_lose_two"] = huber_lose_two
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/MAPERSAC.py
@@ -43,7 +43,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -284,13 +285,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
14 changes: 8 additions & 6 deletions cares_reinforcement_learning/algorithm/policy/PERSAC.py
@@ -41,7 +41,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -175,12 +176,13 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/RDSAC.py
@@ -35,7 +35,8 @@ def __init__(
         self.per_alpha = config.per_alpha
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -245,13 +246,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/REDQ.py
@@ -29,7 +29,8 @@ def __init__(
         self.tau = config.tau
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.device = device
 
@@ -193,13 +194,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_totals"] = critic_loss_totals
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(idx, states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(idx, states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             # Update ensemble of target critics
             for critic_net, target_critic_net in zip(
                 self.ensemble_critics, self.target_ensemble_critics
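
REDQ maintains an ensemble of critics, so its target update walks every (critic, target critic) pair; the loop body is truncated by the diff viewer above, but presumably soft-updates each member. A hedged sketch of that ensemble update as a standalone helper (names and the Polyak form are assumptions, not the repository's exact code):

def update_target_ensemble(ensemble_critics, target_ensemble_critics, tau: float) -> None:
    # Soft-update every member of the target critic ensemble (Polyak averaging).
    for critic_net, target_critic_net in zip(ensemble_critics, target_ensemble_critics):
        for param, target_param in zip(critic_net.parameters(), target_critic_net.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)
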
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -42,7 +42,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -167,13 +168,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         info["critic_loss_two"] = critic_loss_two
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor and Alpha
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor and Alpha
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
4 changes: 2 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -54,8 +54,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 2
-        self.target_update_freq = 2
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         actor_beta = 0.9
         critic_beta = 0.9
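
SACAE is the one algorithm that previously hard-coded both frequencies to 2 (actor and target networks updated every second learn step); after this change that schedule has to come from the config. A sketch of requesting the old behaviour under the new fields (the override dict is illustrative; how the config is actually constructed is not shown in this commit):

# Illustrative only: reproduce the former "every 2nd step" schedule explicitly.
sacae_config_overrides = {
    "policy_update_freq": 2,
    "target_update_freq": 2,
}
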
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/SACD.py
@@ -42,7 +42,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.action_num = self.actor_net.num_actions
 
@@ -182,13 +183,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor and Alpha
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor and Alpha
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
16 changes: 9 additions & 7 deletions cares_reinforcement_learning/algorithm/policy/TQC.py
@@ -44,7 +44,8 @@ def __init__(
         )
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.device = device
 
@@ -178,13 +179,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/algorithm/value/DQN.py
@@ -30,7 +30,7 @@ def __init__(
             self.network.parameters(), lr=config.lr
         )
 
-    def select_action_from_policy(self, state):
+    def select_action_from_policy(self, state) -> int:
         self.network.eval()
         with torch.no_grad():
             state_tensor = torch.FloatTensor(state).to(self.device)
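
The DQN change only tightens the signature: greedy action selection is annotated as returning a plain Python int. The diff cuts off before the return, so the sketch below fills in a typical greedy-selection body under assumptions; it is not the repository's exact implementation:

import numpy as np
import torch


def greedy_action(network: torch.nn.Module, state: np.ndarray, device: torch.device) -> int:
    # Sketch of DQN-style greedy action selection returning a plain int.
    network.eval()
    with torch.no_grad():
        state_tensor = torch.FloatTensor(state).to(device)
        state_tensor = state_tensor.unsqueeze(0)         # add a batch dimension
        q_values = network(state_tensor)                 # shape: (1, num_actions)
        action = int(torch.argmax(q_values).item())      # index of the best action
    network.train()
    return action
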
3 changes: 2 additions & 1 deletion cares_reinforcement_learning/networks/SACD/actor.py
@@ -30,11 +30,12 @@ def __init__(
 
     def forward(
         self, state: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
         action_probs = self.act_net(state)
         max_probability_action = torch.argmax(action_probs)
         dist = torch.distributions.Categorical(action_probs)
         action = dist.sample()
+
         # Offset any values which are zero by a small amount so no nan nonsense
         zero_offset = action_probs == 0.0
         zero_offset = zero_offset.float() * 1e-8
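
The discrete SAC actor's return annotation now matches what the method produces: a sampled action, a nested pair of tensors, and the greedy (arg-max) action; one plausible reading of the nested pair is (action probabilities, their logs), which is what the zero-offset trick protects. A self-contained sketch consistent with that annotation (the network shape and the return ordering are assumptions; only part of the real forward is visible in the diff):

import torch
from torch import nn


class CategoricalActorSketch(nn.Module):
    # Illustrative discrete-action actor; not the library's implementation.

    def __init__(self, observation_size: int, num_actions: int, hidden_size: int = 256):
        super().__init__()
        self.act_net = nn.Sequential(
            nn.Linear(observation_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_actions),
            nn.Softmax(dim=-1),
        )

    def forward(
        self, state: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        action_probs = self.act_net(state)
        max_probability_action = torch.argmax(action_probs, dim=-1)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()

        # Offset exact zeros by a tiny amount so log() cannot produce -inf/NaN.
        zero_offset = (action_probs == 0.0).float() * 1e-8
        log_action_probs = torch.log(action_probs + zero_offset)

        return action, (action_probs, log_action_probs), max_probability_action
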