Added adaptive learning rate feature. #6180

Draft · wants to merge 2 commits into base: develop
26 changes: 25 additions & 1 deletion ml-agents/mlagents/trainers/agent_processor.py
@@ -27,7 +27,7 @@
GlobalAgentId,
GlobalGroupId,
)
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple
from mlagents.trainers.torch_entities.utils import ModelUtils

T = TypeVar("T")
@@ -251,6 +251,28 @@ def _process_step(
except KeyError:
log_probs_tuple = LogProbsTuple.empty_log_probs()

try:
stored_action_mus = stored_take_action_outputs["mus"]
if not isinstance(stored_action_mus, MusTuple):
stored_action_mus = stored_action_mus.to_mus_tuple()
mus_tuple = MusTuple(
continuous=stored_action_mus.continuous[idx],
discrete=stored_action_mus.discrete[idx],
)
except KeyError:
mus_tuple = MusTuple.empty_mus()

try:
stored_action_sigmas = stored_take_action_outputs["sigmas"]
if not isinstance(stored_action_sigmas, SigmasTuple):
stored_action_sigmas = stored_action_sigmas.to_sigmas_tuple()
sigmas_tuple = SigmasTuple(
continuous=stored_action_sigmas.continuous[idx],
discrete=stored_action_sigmas.discrete[idx],
)
except KeyError:
sigmas_tuple = SigmasTuple.empty_sigmas()

action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]

@@ -266,6 +288,8 @@ def _process_step(
done=done,
action=action_tuple,
action_probs=log_probs_tuple,
action_mus=mus_tuple,
action_sigmas=sigmas_tuple,
action_mask=action_mask,
prev_action=prev_action,
interrupted=interrupted,
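Editorial note: the two new try/except blocks mirror the existing log-probs handling. For each agent index, the per-agent slice of the stored means and standard deviations is pulled out of `stored_take_action_outputs` and attached to the experience. A minimal sketch of that slicing with hypothetical shapes (3 agents, 2 continuous actions, 1 discrete branch); the real values come from the policy's stored action outputs:

```python
import numpy as np
from mlagents.trainers.torch_entities.action_log_probs import MusTuple, SigmasTuple

# Hypothetical batch of stored outputs: 3 agents, 2 continuous actions, 1 discrete branch.
stored_mus = MusTuple(
    continuous=np.zeros((3, 2), dtype=np.float32),
    discrete=np.zeros((3, 1), dtype=np.float32),
)
stored_sigmas = SigmasTuple(
    continuous=np.ones((3, 2), dtype=np.float32),
    discrete=np.ones((3, 1), dtype=np.float32),
)

idx = 0  # index of the agent being processed in _process_step
mus_tuple = MusTuple(
    continuous=stored_mus.continuous[idx], discrete=stored_mus.discrete[idx]
)
sigmas_tuple = SigmasTuple(
    continuous=stored_sigmas.continuous[idx], discrete=stored_sigmas.discrete[idx]
)
```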
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/buffer.py
@@ -27,6 +27,10 @@ class BufferKey(enum.Enum):
CONTINUOUS_ACTION = "continuous_action"
NEXT_CONT_ACTION = "next_continuous_action"
CONTINUOUS_LOG_PROBS = "continuous_log_probs"
CONTINUOUS_MUS = "continuous_mus"
DISCRETE_MUS = "discrete_mus"
CONTINUOUS_SIGMAS = "continuous_sigmas"
DISCRETE_SIGMAS = "discrete_sigmas"
DISCRETE_ACTION = "discrete_action"
NEXT_DISC_ACTION = "next_discrete_action"
DISCRETE_LOG_PROBS = "discrete_log_probs"
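Editorial note: these four keys give the trajectory buffer named slots for the Gaussian parameters next to the existing action and log-prob fields. The trajectory-to-buffer conversion is not included in this diff; a minimal sketch of how the new keys would presumably be written and read, assuming the usual AgentBuffer append/lookup API:

```python
import numpy as np
from mlagents.trainers.buffer import AgentBuffer, BufferKey

buffer = AgentBuffer()
# Hypothetical single experience with two continuous actions.
buffer[BufferKey.CONTINUOUS_MUS].append(np.array([0.1, -0.3], dtype=np.float32))
buffer[BufferKey.CONTINUOUS_SIGMAS].append(np.array([0.5, 0.7], dtype=np.float32))

# Later, ActionMus.from_buffer / ActionSigmas.from_buffer check for these keys:
assert BufferKey.CONTINUOUS_MUS in buffer
assert BufferKey.CONTINUOUS_SIGMAS in buffer
```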
26 changes: 22 additions & 4 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -15,7 +15,7 @@
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil

@@ -66,8 +66,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.hyperparameters.lr_min,
self.trainer_settings.max_steps,
self.hyperparameters.desired_lr_kl,
self.hyperparameters.lr_max
)
self.decay_epsilon = ModelUtils.DecayedValue(
self.hyperparameters.epsilon_schedule,
@@ -92,6 +94,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):

self.stream_names = list(self.reward_signals.keys())

self.loss = torch.zeros(1, device=default_device())

self.last_actions = None

@property
def critic(self):
return self._critic
@@ -153,13 +159,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

log_probs = run_out["log_probs"]
entropy = run_out["entropy"]
mus = run_out["mus"]
sigmas = run_out["sigmas"]

values, _ = self.critic.critic_pass(
current_obs,
memories=value_memories,
sequence_length=self.policy.sequence_length,
)
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
old_mus = ActionMus.from_buffer(batch).flatten()
old_sigmas = ActionSigmas.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
value_loss = ModelUtils.trust_region_value_loss(
@@ -172,16 +182,22 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
loss_masks,
decay_eps,
)
loss = (
self.loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)

# adaptive learning rate
if self.hyperparameters.learning_rate_schedule == ScheduleType.ADAPTIVE:
decay_lr = self.decay_learning_rate.get_value(
self.policy.get_current_step(), mus, old_mus, sigmas, old_sigmas
)

# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.loss.backward()

self.optimizer.step()
update_stats = {
@@ -194,6 +210,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"Policy/Beta": decay_bet,
}

self.loss = torch.zeros(1, device=default_device())

return update_stats

# TODO move module update into TorchOptimizer for reward_provider
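Editorial note: the optimizer now feeds the new and old Gaussian parameters into `DecayedValue.get_value`, but the adaptive rule itself lives in `ModelUtils.DecayedValue` and is not shown in this diff. A plausible sketch of the kind of KL-driven update this appears to implement, in the spirit of the adaptive step-size heuristics used with PPO: measure the KL divergence between the old and new policies, shrink the learning rate when the update overshoots `desired_lr_kl`, grow it when it undershoots, and clamp to `[lr_min, lr_max]`. The 1.5 threshold and the factor-of-2 adjustment are assumptions, not constants confirmed by the PR:

```python
import torch

def adaptive_lr(
    lr: float,
    mus: torch.Tensor,
    sigmas: torch.Tensor,       # parameters of the updated policy
    old_mus: torch.Tensor,
    old_sigmas: torch.Tensor,   # parameters stored in the buffer
    desired_lr_kl: float = 0.008,
    lr_min: float = 1.0e-10,
    lr_max: float = 1.0e-2,
) -> float:
    # KL(old || new) for diagonal Gaussians, summed over action dims, averaged over the batch.
    kl = (
        torch.log(sigmas / old_sigmas)
        + (old_sigmas.pow(2) + (old_mus - mus).pow(2)) / (2.0 * sigmas.pow(2))
        - 0.5
    ).sum(dim=-1).mean()

    # Step size down if the policy moved too far, up if it barely moved.
    if kl > desired_lr_kl * 1.5:
        lr = lr / 2.0
    elif kl < desired_lr_kl / 1.5:
        lr = lr * 2.0
    return float(min(max(lr, lr_min), lr_max))
```

This sketch has the same shape as the adaptive KL heuristic from the PPO paper, applied here to the optimizer learning rate rather than a penalty coefficient.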
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -105,6 +105,7 @@ class EncoderType(Enum):
class ScheduleType(Enum):
CONSTANT = "constant"
LINEAR = "linear"
ADAPTIVE = "adaptive"
# TODO add support for lesson based scheduling
# LESSON = "lesson"

@@ -158,6 +159,9 @@ class HyperparamSettings:
batch_size: int = 1024
buffer_size: int = 10240
learning_rate: float = 3.0e-4
desired_lr_kl: float = 0.008
lr_min: float = 1.0e-10
lr_max: float = 1.0e-2
learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT


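Editorial note: with these defaults, the new fields are inert unless `learning_rate_schedule` is set to the new adaptive option. A minimal sketch of the corresponding settings object; the matching trainer-YAML keys are presumably `desired_lr_kl`, `lr_min` and `lr_max`, but no example config is included in this diff:

```python
from mlagents.trainers.settings import HyperparamSettings, ScheduleType

# Hypothetical hyperparameters using the adaptive schedule added by this PR.
hp = HyperparamSettings(
    batch_size=1024,
    buffer_size=10240,
    learning_rate=3.0e-4,
    learning_rate_schedule=ScheduleType.ADAPTIVE,
    desired_lr_kl=0.008,  # target KL divergence per update
    lr_min=1.0e-10,       # floor for the adapted learning rate
    lr_max=1.0e-2,        # ceiling for the adapted learning rate
)
```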
170 changes: 170 additions & 0 deletions ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
@@ -7,6 +7,38 @@
from mlagents_envs.base_env import _ActionTupleBase


class MusTuple(_ActionTupleBase):
"""
An object whose fields correspond to the means (mus) of the action distributions for
continuous and discrete spaces. Dimensions are (n_agents, continuous_size) and
(n_agents, discrete_size), respectively. Note that this also holds when the continuous
or discrete size is zero.
"""
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_mus() -> "MusTuple":
return MusTuple()


class SigmasTuple(_ActionTupleBase):
"""
An object whose fields correspond to the standard deviations (sigmas) of the action
distributions for continuous and discrete spaces. Dimensions are (n_agents, continuous_size)
and (n_agents, discrete_size), respectively. Note that this also holds when the continuous
or discrete size is zero.
"""
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_sigmas() -> "SigmasTuple":
return SigmasTuple()


class LogProbsTuple(_ActionTupleBase):
"""
An object whose fields correspond to the log probs of actions of different types.
@@ -116,3 +148,141 @@ def from_buffer(buff: AgentBuffer) -> "ActionLogProbs":
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionLogProbs(continuous, discrete, None)


class ActionMus(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete mus list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete mus of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_mus_tuple(self) -> MusTuple:
mus_tuple = MusTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
mus_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
mus_tuple.add_discrete(discrete)
return mus_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionMus as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all mus in ActionMus as a flattened tensor.
This is useful for algorithms like PPO which can treat all mus in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionMus":
"""
A static method that accesses the continuous and discrete mus fields in an AgentBuffer
and constructs the corresponding ActionMus from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_MUS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_MUS])
if BufferKey.DISCRETE_MUS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_MUS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionMus(continuous, discrete, None)


class ActionSigmas(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete sigmas list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete sigmas of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_sigmas_tuple(self) -> SigmasTuple:
sigmas_tuple = SigmasTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
sigmas_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
sigmas_tuple.add_discrete(discrete)
return sigmas_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionSigmas as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all sigmas in ActionSigmas as a flattened tensor.
This is useful for algorithms like PPO which can treat all sigmas in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionSigmas":
"""
A static method that accesses the continuous and discrete sigmas fields in an AgentBuffer
and constructs the corresponding ActionSigmas from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_SIGMAS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_SIGMAS])
if BufferKey.DISCRETE_SIGMAS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_SIGMAS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionSigmas(continuous, discrete, None)
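Editorial note: ActionMus and ActionSigmas are near-verbatim copies of ActionLogProbs, so they follow the same life cycle: built from tensors on the policy side, converted to numpy tuples for the trajectory, then rebuilt from the buffer and flattened for the optimizer. A short sketch of that round trip with hypothetical values (continuous-only, 2 agents, 3 actions):

```python
import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.action_log_probs import ActionMus

# Policy side: means for 2 agents with 3 continuous actions each.
mus = ActionMus(torch.zeros(2, 3), None, None)
mus_tuple = mus.to_mus_tuple()  # numpy MusTuple, stored with the trajectory

# Buffer side: append one row per experience, then rebuild for the update step.
buffer = AgentBuffer()
for row in mus_tuple.continuous:
    buffer[BufferKey.CONTINUOUS_MUS].append(row)

old_mus = ActionMus.from_buffer(buffer).flatten()  # tensor of shape (2, 3)
```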
28 changes: 24 additions & 4 deletions ml-agents/mlagents/trainers/torch_entities/action_model.py
@@ -7,7 +7,7 @@
MultiCategoricalDistribution,
)
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents_envs.base_env import ActionSpec


@@ -146,9 +146,25 @@ def _get_probs_and_entropy(
entropies = torch.cat(entropies_list, dim=1)
return action_log_probs, entropies

def _get_mus_and_sigmas(self, actions: AgentAction, dists: DistInstances) -> Tuple[ActionMus, ActionSigmas]:
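"""
Builds ActionMus and ActionSigmas from the given distribution instances. Only the
continuous mu and sigma are populated for now; the discrete fields are left as None.
"""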
continuous_mus: Optional[torch.Tensor] = None
continuous_sigmas: Optional[torch.Tensor] = None
discrete_mus: Optional[torch.Tensor] = None
discrete_sigmas: Optional[torch.Tensor] = None
all_discrete_mus: Optional[List[torch.Tensor]] = None
all_discrete_sigmas: Optional[List[torch.Tensor]] = None
if dists.continuous is not None:
continuous_mus = dists.continuous.mu()
continuous_sigmas = dists.continuous.sigma()
action_mus = ActionMus(continuous_mus, discrete_mus, all_discrete_mus)
action_sigmas = ActionSigmas(
continuous_sigmas, discrete_sigmas, all_discrete_sigmas
)
return action_mus, action_sigmas

def evaluate(
self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
) -> Tuple[ActionLogProbs, torch.Tensor]:
) -> Tuple[ActionLogProbs, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given actions and encoding from the network body, gets the distributions and
computes the log probabilities and entropies, along with the continuous means and sigmas.
@@ -159,9 +175,12 @@ def evaluate(
"""
dists = self._get_dists(inputs, masks)
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
mus = dists.continuous.mu() if dists.continuous is not None else None
sigmas = dists.continuous.sigma() if dists.continuous is not None else None
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return log_probs, entropy_sum
return log_probs, entropy_sum, mus, sigmas

def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
@@ -228,4 +247,5 @@ def forward(
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return (actions, log_probs, entropy_sum)
mus, sigmas = self._get_mus_and_sigmas(actions, dists)
return (actions, log_probs, entropy_sum, mus, sigmas)
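Editorial note: both `evaluate` and `forward` now return the continuous means and standard deviations alongside the log probs and entropy, so every caller that unpacks these tuples (for example the policy's evaluate path and the ONNX export wrapper around `forward`) needs the wider signature; those call sites are not part of this diff. A minimal sketch of the new return shapes, assuming the distribution-side `mu()`/`sigma()` additions that this PR relies on but does not show here:

```python
import torch
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch_entities.action_model import ActionModel

# Hypothetical: hidden size 16, 3 continuous actions, batch of 2 agents.
action_model = ActionModel(16, ActionSpec.create_continuous(3))
encoding = torch.randn(2, 16)
masks = torch.ones(2, 0)  # no discrete branches, so the mask is empty

# forward() returns ActionMus/ActionSigmas objects; evaluate() returns raw tensors.
actions, log_probs, entropy_sum, action_mus, action_sigmas = action_model.forward(encoding, masks)
log_probs, entropy_sum, mus, sigmas = action_model.evaluate(encoding, masks, actions)
```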