diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index f435a0a8..d90d48e7 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
 from pearl.action_representation_modules.action_representation_module import (
@@ -95,6 +95,10 @@ def __init__(
         temperature_advantage_weighted_regression: float = 0.5,
         advantage_clamp: float = 100.0,
         action_representation_module: Optional[ActionRepresentationModule] = None,
+        actor_network_instance: Optional[ActorNetwork] = None,
+        critic_network_instance: Optional[
+            Union[ValueNetwork, QValueNetwork, torch.nn.Module]
+        ] = None,
     ) -> None:
         super(ImplicitQLearning, self).__init__(
             state_dim=state_dim,
@@ -120,6 +124,8 @@ def __init__(
             is_action_continuous=action_space.is_continuous,  # inferred from the action space
             on_policy=False,
             action_representation_module=action_representation_module,
+            actor_network_instance=actor_network_instance,
+            critic_network_instance=critic_network_instance,
         )
 
         self._expectile = expectile
diff --git a/pearl/utils/functional_utils/learning/critic_utils.py b/pearl/utils/functional_utils/learning/critic_utils.py
index 51d9d7ab..09424b91 100644
--- a/pearl/utils/functional_utils/learning/critic_utils.py
+++ b/pearl/utils/functional_utils/learning/critic_utils.py
@@ -125,8 +125,12 @@ def update_critic_target_network(
         )
     else:
         update_target_network(
-            target_network._model,
-            network._model,
+            (
+                target_network._model
+                if hasattr(target_network, "_model")
+                else target_network
+            ),
+            network._model if hasattr(network, "_model") else network,
             tau=tau,
         )
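
The `hasattr` dispatch in `update_critic_target_network` lets the target-network update accept either a Pearl critic wrapper that stores its network in `._model` or a plain `torch.nn.Module`, such as a user-supplied `critic_network_instance`. Below is a minimal, self-contained sketch of that pattern; the `soft_update` helper and `Wrapper` class are hypothetical stand-ins (not Pearl's `update_target_network` or its critic wrappers), used only to illustrate the dispatch.

```python
import torch
import torch.nn as nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)


class Wrapper(nn.Module):
    # Stand-in for a critic wrapper that keeps the actual network in `._model`.
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self._model = model


target_net = Wrapper(nn.Linear(4, 1))  # wrapped critic: exposes `._model`
source_net = nn.Linear(4, 1)           # raw module: no `._model` attribute

# Same dispatch as in the diff: unwrap `._model` when present,
# otherwise pass the module itself.
soft_update(
    target_net._model if hasattr(target_net, "_model") else target_net,
    source_net._model if hasattr(source_net, "_model") else source_net,
    tau=0.05,
)
```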