From a1480fb8e0cc54b3228e926356058af8391ff290 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 27 Jan 2025 11:53:43 +0800 Subject: [PATCH 1/5] polish(pu): delete unused enable_fast_timestep argument --- ding/model/template/collaq.py | 6 +- ding/model/template/coma.py | 2 +- ding/model/template/q_learning.py | 130 +++++++++++++++++------------- ding/model/template/qmix.py | 2 +- ding/model/template/qtran.py | 2 +- ding/model/template/wqmix.py | 4 +- ding/policy/ngu.py | 6 +- ding/policy/r2d2.py | 6 +- ding/policy/r2d2_collect_traj.py | 6 +- ding/policy/r2d3.py | 8 +- 10 files changed, 93 insertions(+), 79 deletions(-) diff --git a/ding/model/template/collaq.py b/ding/model/template/collaq.py index 9872d0684a..fe2437f7f5 100644 --- a/ding/model/template/collaq.py +++ b/ding/model/template/collaq.py @@ -415,21 +415,21 @@ def forward(self, data: dict, single_step: bool = True) -> dict: { 'obs': agent_state, 'prev_state': colla_prev_state, - 'enable_fast_timestep': True + } ) colla_alone_output = self._q_network( { 'obs': agent_alone_padding_state, 'prev_state': colla_alone_prev_state, - 'enable_fast_timestep': True + } ) alone_output = self._q_alone_network( { 'obs': agent_alone_state, 'prev_state': alone_prev_state, - 'enable_fast_timestep': True + } ) diff --git a/ding/model/template/coma.py b/ding/model/template/coma.py index 02eb286e84..5ed7973635 100644 --- a/ding/model/template/coma.py +++ b/ding/model/template/coma.py @@ -70,7 +70,7 @@ def forward(self, inputs: Dict) -> Dict: T, B, A = agent_state.shape[:3] agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:]) prev_state = reduce(lambda x, y: x + y, prev_state) - output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True}) + output = self.main({'obs': agent_state, 'prev_state': prev_state}) logit, next_state = output['logit'], output['next_state'] next_state, _ = list_split(next_state, step=A) logit = logit.reshape(T, B, A, -1) diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index eaec8fc45b..66db7520e7 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -855,18 +855,23 @@ def reshape(d): class DRQN(nn.Module): """ Overview: - The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \ - common DQN variant for sequential data and paratially observable environment. The DRQN is composed of three \ - parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \ - observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \ - used to compute the Q value of each action dimension. + The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential + data and partially observable environments. It consists of three main components: ``encoder``, ``rnn``, + and ``head``. + - **Encoder**: Extracts features from various observation inputs. + - **RNN**: Processes sequential observations and other data. + - **Head**: Computes Q-values for each action dimension. + Interfaces: - ``__init__``, ``forward``. + - ``__init__``: Initializes the DRQN model. + - ``forward``: Defines the forward computation graph. .. note:: - Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \ - ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \ - ``gru``. You can customize your own encoder, rnn or head by inheriting this class. + The current implementation supports: + - Two encoder types: ``FCEncoder`` and ``ConvEncoder``. + - Two head types: ``DiscreteHead`` and ``DuelingHead``. + - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``. + You can extend the model by customizing your own encoder, RNN, or head by inheriting this class. """ def __init__( @@ -884,43 +889,47 @@ def __init__( ) -> None: """ Overview: - Initialize the DRQN Model according to the corresponding input arguments. + Initializes the DRQN model with specified parameters. Arguments: - - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. - - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. - - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ - the last element must match ``head_hidden_size``. - - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``. - - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \ - then it will be set to the last element of ``encoder_hidden_size_list``. - - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output. - - lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru']. - - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ - if ``None`` then default set it to ``nn.ReLU()``. - - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ - ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN'] - - res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connnection between \ - single frame data and the sequential data, defaults to False. + - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84]. + - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3]. + - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \ + must match ``head_hidden_size``. + - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``. + - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \ + element of ``encoder_hidden_size_list`` if None. + - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs. + - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \ + and ``gru``. + - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to ``nn.ReLU()``. + - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \ + ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details. + - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \ + Defaults to False. """ super(DRQN, self).__init__() - # For compatibility: 1, (1, ), [4, 32, 32] + # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs. obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape) if head_hidden_size is None: head_hidden_size = encoder_hidden_size_list[-1] - # FC Encoder + + # Encoder: Determines the encoder type based on the observation shape. if isinstance(obs_shape, int) or len(obs_shape) == 1: + # FC Encoder self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) - # Conv Encoder elif len(obs_shape) == 3: + # Conv Encoder self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) else: raise RuntimeError( - "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape) + f"Unsupported obs_shape for pre-defined encoder: {obs_shape}. Please customize your own DRQN." ) - # LSTM Type + + # RNN: Initializes the RNN module based on the specified lstm_type. self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size) self.res_link = res_link - # Head Type + + # Head: Determines the head type (Dueling or Discrete) and its configuration. if dueling: head_cls = DuelingHead else: @@ -943,31 +952,35 @@ def __init__( def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict: """ Overview: - DRQN forward computation graph, input observation tensor to predict q_value. + Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \ + and predicts Q-values. + Arguments: - - inputs (:obj:`torch.Tensor`): The dict of input data, including observation and previous rnn state. - - inference: (:obj:'bool'): Whether to enable inference forward mode, if True, we unroll the one timestep \ - transition, otherwise, we unroll the eentire sequence transitions. - - saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, we unroll the sequence \ - transitions, then we would use this list to indicate how to save and return hidden state. - ArgumentsKeys: - - obs (:obj:`torch.Tensor`): The raw observation tensor. - - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``. + - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state. + - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \ + the entire sequence (used during training). + - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \ + whose hidden states are saved and returned. + Input Keys: + - obs (:obj:`torch.Tensor`): Raw observation tensor. + - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``. + Returns: - - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state. - ReturnsKeys: - - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension. - - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type``. + - outputs (:obj:`Dict`): Dictionary containing the Q-value logits and next RNN state. + Output Keys: + - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension. + - next_state (:obj:`list`): Next RNN state tensor. Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape`` - - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape`` + - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``. + - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``. + Examples: - >>> # Init input's Keys: + >>> # Initialize input keys >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4 >>> obs = torch.randn(4,64) >>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape' >>> outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True) - >>> # Check outputs's Keys + >>> # Validate output keys and shapes >>> assert isinstance(outputs, dict) >>> assert outputs['logit'].shape == (4, 64) >>> assert len(outputs['next_state']) == 4 @@ -976,9 +989,9 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: """ x, prev_state = inputs['obs'], inputs['prev_state'] - # for both inference and other cases, the network structure is encoder -> rnn network -> head - # the difference is inference take the data with seq_len=1 (or T = 1) - # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training + # Forward pass: Encoder -> RNN -> Head + # in most situations, set inference=True when evaluate and inference=False when training + # Inference mode: Processes one timestep (seq_len=1). if inference: x = self.encoder(x) if self.res_link: @@ -992,6 +1005,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: x = self.head(x) x['next_state'] = next_state return x + # Training mode: Processes the entire sequence. else: # In order to better explain why rnn needs saved_state and which states need to be stored, # let's take r2d2 as an example @@ -999,20 +1013,20 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: # 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep] # 2) data['main_obs'] = data['obs'][bs:-self._nstep] # 3) data['target_obs'] = data['obs'][bs + self._nstep:] - # NOTE(rjy): (T, B, N) or (T, B, C, H, W) - assert len(x.shape) in [3, 5], x.shape + assert len(x.shape) in [3, 5], f"Expected shape (T, B, N) or (T, B, C, H, W), got {x.shape}" x = parallel_wrapper(self.encoder)(x) # (T, B, N) if self.res_link: a = x - # NOTE(rjy) lstm_embedding stores all hidden_state + # lstm_embedding stores all hidden_state lstm_embedding = [] # TODO(nyz) how to deal with hidden_size key-value hidden_state_list = [] + if saved_state_timesteps is not None: saved_state = [] - for t in range(x.shape[0]): # T timesteps - # NOTE(rjy) use x[t:t+1] but not x[t] can keep original dimension - output, prev_state = self.rnn(x[t:t + 1], prev_state) # output: (1,B, head_hidden_size) + for t in range(x.shape[0]): # Iterate over timesteps (T). + # use x[t:t+1] but not x[t] can keep the original dimension + output, prev_state = self.rnn(x[t:t + 1], prev_state) # RNN step output: (1, B, hidden_size) if saved_state_timesteps is not None and t + 1 in saved_state_timesteps: saved_state.append(prev_state) lstm_embedding.append(output) @@ -1023,7 +1037,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: if self.res_link: x = x + a x = parallel_wrapper(self.head)(x) # (T, B, action_shape) - # NOTE(rjy): x['next_state'] is the hidden state of the last timestep inputted to lstm + # x['next_state'] is the hidden state of the last timestep inputted to lstm # the last timestep state including the hidden state (h) and the cell state (c) # shape: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size}}} x['next_state'] = prev_state diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index b5cde0806b..e9f2475819 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -227,7 +227,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict: ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0])) prev_state = reduce(lambda x, y: x + y, prev_state) agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:]) - output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True}) + output = self._q_network({'obs': agent_state, 'prev_state': prev_state}) agent_q, next_state = output['logit'], output['next_state'] next_state, _ = list_split(next_state, step=A) agent_q = agent_q.reshape(T, B, A, -1) diff --git a/ding/model/template/qtran.py b/ding/model/template/qtran.py index 6e627f1d15..9114245f13 100644 --- a/ding/model/template/qtran.py +++ b/ding/model/template/qtran.py @@ -100,7 +100,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict: ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0])) prev_state = reduce(lambda x, y: x + y, prev_state) agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:]) - output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True}) + output = self._q_network({'obs': agent_state, 'prev_state': prev_state}) agent_q, next_state = output['logit'], output['next_state'] next_state, _ = list_split(next_state, step=A) agent_q = agent_q.reshape(T, B, A, -1) diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py index f80aa25d4a..13a3335c11 100644 --- a/ding/model/template/wqmix.py +++ b/ding/model/template/wqmix.py @@ -177,7 +177,7 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - 'enable_fast_timestep': True + } ) # here is the forward pass of the agent networks of Q_star agent_q, next_state = output['logit'], output['next_state'] @@ -223,7 +223,7 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - 'enable_fast_timestep': True + } ) # here is the forward pass of the agent networks of Q agent_q, next_state = output['logit'], output['next_state'] diff --git a/ding/policy/ngu.py b/ding/policy/ngu.py index 95fe2dd82a..1898b83247 100644 --- a/ding/policy/ngu.py +++ b/ding/policy/ngu.py @@ -290,7 +290,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['burnin_nstep_action'], 'reward': data['burnin_nstep_reward'], 'beta': data['burnin_nstep_beta'], - 'enable_fast_timestep': True + } tmp = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep] @@ -304,7 +304,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['main_action'], 'reward': data['main_reward'], 'beta': data['main_beta'], - 'enable_fast_timestep': True + } self._learn_model.reset(data_id=None, state=tmp['saved_state'][0]) q_value = self._learn_model.forward(inputs)['logit'] @@ -317,7 +317,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['target_action'], 'reward': data['target_reward'], 'beta': data['target_beta'], - 'enable_fast_timestep': True + } with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] diff --git a/ding/policy/r2d2.py b/ding/policy/r2d2.py index 0726c2c820..420d238cad 100644 --- a/ding/policy/r2d2.py +++ b/ding/policy/r2d2.py @@ -317,7 +317,7 @@ def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]: if len(data['burnin_nstep_obs']) != 0: with torch.no_grad(): - inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['burnin_nstep_obs']} burnin_output = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep] ) # keys include 'logit', 'hidden_state' 'saved_state', \ @@ -327,12 +327,12 @@ def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]: ) self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0]) - inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['main_obs']} q_value = self._learn_model.forward(inputs)['logit'] self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1]) self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1]) - next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True} + next_inputs = {'obs': data['target_obs']} with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] # argmax_action double_dqn diff --git a/ding/policy/r2d2_collect_traj.py b/ding/policy/r2d2_collect_traj.py index 2cf312010f..5dbe7c3eb1 100644 --- a/ding/policy/r2d2_collect_traj.py +++ b/ding/policy/r2d2_collect_traj.py @@ -266,7 +266,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: if len(data['burnin_nstep_obs']) != 0: with torch.no_grad(): - inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['burnin_nstep_obs']} burnin_output = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep] ) @@ -275,12 +275,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: ) self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0]) - inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['main_obs']} q_value = self._learn_model.forward(inputs)['logit'] self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1]) self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1]) - next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True} + next_inputs = {'obs': data['target_obs']} with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] # argmax_action double_dqn diff --git a/ding/policy/r2d3.py b/ding/policy/r2d3.py index feb8362921..e4f3e56a6a 100644 --- a/ding/policy/r2d3.py +++ b/ding/policy/r2d3.py @@ -275,7 +275,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: if len(data['burnin_nstep_obs']) != 0: with torch.no_grad(): - inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['burnin_nstep_obs']} burnin_output = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep, self._burnin_step + 1] @@ -286,14 +286,14 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: ) self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0]) - inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True} + inputs = {'obs': data['main_obs']} q_value = self._learn_model.forward(inputs)['logit'] # n-step self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1]) self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1]) - next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True} + next_inputs = {'obs': data['target_obs']} with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] # argmax_action double_dqn @@ -303,7 +303,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][2]) self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][2]) - next_inputs_one_step = {'obs': data['target_obs_one_step'], 'enable_fast_timestep': True} + next_inputs_one_step = {'obs': data['target_obs_one_step']} with torch.no_grad(): target_q_value_one_step = self._target_model.forward(next_inputs_one_step)['logit'] # argmax_action double_dqn From ecfe91a7c7b52dd38cefb389d8bbd94d1593bba8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 27 Jan 2025 11:59:57 +0800 Subject: [PATCH 2/5] polish(pu): delete unused empty lines --- ding/model/template/collaq.py | 3 --- ding/model/template/wqmix.py | 2 -- ding/policy/ngu.py | 3 --- 3 files changed, 8 deletions(-) diff --git a/ding/model/template/collaq.py b/ding/model/template/collaq.py index fe2437f7f5..0fa4a2c841 100644 --- a/ding/model/template/collaq.py +++ b/ding/model/template/collaq.py @@ -415,21 +415,18 @@ def forward(self, data: dict, single_step: bool = True) -> dict: { 'obs': agent_state, 'prev_state': colla_prev_state, - } ) colla_alone_output = self._q_network( { 'obs': agent_alone_padding_state, 'prev_state': colla_alone_prev_state, - } ) alone_output = self._q_alone_network( { 'obs': agent_alone_state, 'prev_state': alone_prev_state, - } ) diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py index 13a3335c11..251c025f9e 100644 --- a/ding/model/template/wqmix.py +++ b/ding/model/template/wqmix.py @@ -177,7 +177,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - } ) # here is the forward pass of the agent networks of Q_star agent_q, next_state = output['logit'], output['next_state'] @@ -223,7 +222,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - } ) # here is the forward pass of the agent networks of Q agent_q, next_state = output['logit'], output['next_state'] diff --git a/ding/policy/ngu.py b/ding/policy/ngu.py index 1898b83247..51fabda807 100644 --- a/ding/policy/ngu.py +++ b/ding/policy/ngu.py @@ -290,7 +290,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['burnin_nstep_action'], 'reward': data['burnin_nstep_reward'], 'beta': data['burnin_nstep_beta'], - } tmp = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep] @@ -304,7 +303,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['main_action'], 'reward': data['main_reward'], 'beta': data['main_beta'], - } self._learn_model.reset(data_id=None, state=tmp['saved_state'][0]) q_value = self._learn_model.forward(inputs)['logit'] @@ -317,7 +315,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['target_action'], 'reward': data['target_reward'], 'beta': data['target_beta'], - } with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] From a9152a6b96406206cdb8821adc00cffc4bef998c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 27 Jan 2025 11:59:57 +0800 Subject: [PATCH 3/5] polish(pu): delete unused empty lines --- ding/model/template/collaq.py | 23 ++++++++--------------- ding/model/template/wqmix.py | 2 -- ding/policy/ngu.py | 3 --- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/ding/model/template/collaq.py b/ding/model/template/collaq.py index fe2437f7f5..3a1e4fec08 100644 --- a/ding/model/template/collaq.py +++ b/ding/model/template/collaq.py @@ -411,27 +411,20 @@ def forward(self, data: dict, single_step: bool = True) -> dict: agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:]) agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:]) - colla_output = self._q_network( - { - 'obs': agent_state, - 'prev_state': colla_prev_state, - - } - ) + colla_output = self._q_network({ + 'obs': agent_state, + 'prev_state': colla_prev_state, + }) colla_alone_output = self._q_network( { 'obs': agent_alone_padding_state, 'prev_state': colla_alone_prev_state, - - } - ) - alone_output = self._q_alone_network( - { - 'obs': agent_alone_state, - 'prev_state': alone_prev_state, - } ) + alone_output = self._q_alone_network({ + 'obs': agent_alone_state, + 'prev_state': alone_prev_state, + }) agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state'] agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state'] diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py index 13a3335c11..251c025f9e 100644 --- a/ding/model/template/wqmix.py +++ b/ding/model/template/wqmix.py @@ -177,7 +177,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - } ) # here is the forward pass of the agent networks of Q_star agent_q, next_state = output['logit'], output['next_state'] @@ -223,7 +222,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> { 'obs': agent_state, 'prev_state': prev_state, - } ) # here is the forward pass of the agent networks of Q agent_q, next_state = output['logit'], output['next_state'] diff --git a/ding/policy/ngu.py b/ding/policy/ngu.py index 1898b83247..51fabda807 100644 --- a/ding/policy/ngu.py +++ b/ding/policy/ngu.py @@ -290,7 +290,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['burnin_nstep_action'], 'reward': data['burnin_nstep_reward'], 'beta': data['burnin_nstep_beta'], - } tmp = self._learn_model.forward( inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep] @@ -304,7 +303,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['main_action'], 'reward': data['main_reward'], 'beta': data['main_beta'], - } self._learn_model.reset(data_id=None, state=tmp['saved_state'][0]) q_value = self._learn_model.forward(inputs)['logit'] @@ -317,7 +315,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'action': data['target_action'], 'reward': data['target_reward'], 'beta': data['target_beta'], - } with torch.no_grad(): target_q_value = self._target_model.forward(next_inputs)['logit'] From a31a951692d0a51cfa7b5f7aba4d0b6f79854f28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 27 Jan 2025 15:35:00 +0800 Subject: [PATCH 4/5] style(pu): polish comment's format --- ding/model/template/q_learning.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index 66db7520e7..6954a52a28 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -863,8 +863,7 @@ class DRQN(nn.Module): - **Head**: Computes Q-values for each action dimension. Interfaces: - - ``__init__``: Initializes the DRQN model. - - ``forward``: Defines the forward computation graph. + ``__init__``, ``forward``. .. note:: The current implementation supports: @@ -889,23 +888,23 @@ def __init__( ) -> None: """ Overview: - Initializes the DRQN model with specified parameters. + Initialize the DRQN model with specified parameters. Arguments: - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84]. - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3]. - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \ - must match ``head_hidden_size``. + must match ``head_hidden_size``. - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``. - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \ - element of ``encoder_hidden_size_list`` if None. + element of ``encoder_hidden_size_list`` if None. - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs. - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \ - and ``gru``. + and ``gru``. - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to ``nn.ReLU()``. - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \ - ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details. + ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details. - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \ - Defaults to False. + Defaults to False. """ super(DRQN, self).__init__() # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs. @@ -954,26 +953,23 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Overview: Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \ and predicts Q-values. - Arguments: - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state. - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \ - the entire sequence (used during training). + the entire sequence (used during training). - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \ - whose hidden states are saved and returned. - Input Keys: + whose hidden states are saved and returned. + ArgumentsKeys: - obs (:obj:`torch.Tensor`): Raw observation tensor. - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``. - Returns: - - outputs (:obj:`Dict`): Dictionary containing the Q-value logits and next RNN state. - Output Keys: + - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state. + ReturnsKeys: - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension. - next_state (:obj:`list`): Next RNN state tensor. Shapes: - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``. - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``. - Examples: >>> # Initialize input keys >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4 From 6634bcf8f7215b9436f42a063a1ced83088a567f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 27 Jan 2025 15:38:55 +0800 Subject: [PATCH 5/5] style(pu): polish comment's format --- ding/model/template/q_learning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index 6954a52a28..c81479f06b 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -900,7 +900,8 @@ def __init__( - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs. - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \ and ``gru``. - - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to ``nn.ReLU()``. + - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to \ + ``nn.ReLU()``. - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \ ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details. - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \