Skip to content

Commit

Permalink
[BugFix] Compatibility with new Composite dist log_prob/entropy APIs
Browse files Browse the repository at this point in the history
ghstack-source-id: a09b6c34000f57a66736bb9811ca3656c861ec0c
Pull Request resolved: #2435
  • Loading branch information
vmoens committed Sep 12, 2024
1 parent d40fa4f commit 36545af
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 5 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7565,6 +7565,7 @@ def _create_mock_actor(
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -8627,6 +8630,7 @@ def _create_mock_actor(
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down
9 changes: 7 additions & 2 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,13 @@ def _log_probs(
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
- tensordict = dist.log_prob(tensordict)
- log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+ maybe_log_prob = dist.log_prob(tensordict)
+ if not isinstance(maybe_log_prob, torch.Tensor):
+ # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+ # be a tensor
+ log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+ else:
+ log_prob = maybe_log_prob
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist

Expand Down
11 changes: 8 additions & 3 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,13 @@ def _log_weight(
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
- tensordict = dist.log_prob(tensordict)
- log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+ maybe_log_prob = dist.log_prob(tensordict)
+ if not isinstance(maybe_log_prob, torch.Tensor):
+ # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+ # be a tensor
+ log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+ else:
+ log_prob = maybe_log_prob

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
Expand Down Expand Up @@ -1130,7 +1135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
x = previous_dist.sample((self.samples_mc_kl,))
previous_log_prob = previous_dist.log_prob(x)
current_log_prob = current_dist.log_prob(x)
- if is_tensor_collection(x):
+ if is_tensor_collection(current_log_prob):
previous_log_prob = previous_log_prob.get(
self.tensor_keys.sample_log_prob
)
Expand Down

0 comments on commit 36545af

Please sign in to comment.