diff --git a/pearl/utils/functional_utils/learning/critic_utils.py b/pearl/utils/functional_utils/learning/critic_utils.py index a52b675d..51d9d7ab 100644 --- a/pearl/utils/functional_utils/learning/critic_utils.py +++ b/pearl/utils/functional_utils/learning/critic_utils.py @@ -194,4 +194,5 @@ def twin_critic_action_value_loss( loss = criterion( q_1.reshape_as(expected_target_batch), expected_target_batch.detach() ) + criterion(q_2.reshape_as(expected_target_batch), expected_target_batch.detach()) + loss = loss / 2.0 return loss, q_1, q_2