polish(pu): delete unused enable_fast_timestep argument #855

Merged · 6 commits · Jan 27, 2025
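
The recurrent Q-network templates touched by this PR all pass their inputs as a plain dict, and `DRQN.forward` only reads the `obs` and `prev_state` keys (see `q_learning.py` below), so the extra `enable_fast_timestep` entry was never used. Below is a minimal sketch of the call pattern after this change, mirroring the DRQN docstring example; the sizes are illustrative assumptions, not taken from the repository's tests.

```python
import torch
from ding.model.template.q_learning import DRQN

B, obs_dim, act_dim = 4, 64, 6
model = DRQN(obs_dim, act_dim)
obs = torch.randn(B, obs_dim)
# One (h, c) pair per sample; hidden size 64 matches the default encoder output.
prev_state = [[torch.randn(1, 1, 64) for _ in range(2)] for _ in range(B)]

# Previously call sites also passed 'enable_fast_timestep': True, but the key
# was never read by DRQN.forward, so dropping it does not change behavior.
outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
assert outputs['logit'].shape == (B, act_dim)
```
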
23 changes: 8 additions & 15 deletions ding/model/template/collaq.py
@@ -411,27 +411,20 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:])
agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:])

- colla_output = self._q_network(
- {
- 'obs': agent_state,
- 'prev_state': colla_prev_state,
- 'enable_fast_timestep': True
- }
- )
+ colla_output = self._q_network({
+ 'obs': agent_state,
+ 'prev_state': colla_prev_state,
+ })
colla_alone_output = self._q_network(
{
'obs': agent_alone_padding_state,
'prev_state': colla_alone_prev_state,
- 'enable_fast_timestep': True
}
)
- alone_output = self._q_alone_network(
- {
- 'obs': agent_alone_state,
- 'prev_state': alone_prev_state,
- 'enable_fast_timestep': True
- }
- )
+ alone_output = self._q_alone_network({
+ 'obs': agent_alone_state,
+ 'prev_state': alone_prev_state,
+ })

agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state']
agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state']
2 changes: 1 addition & 1 deletion ding/model/template/coma.py
@@ -70,7 +70,7 @@ def forward(self, inputs: Dict) -> Dict:
T, B, A = agent_state.shape[:3]
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
prev_state = reduce(lambda x, y: x + y, prev_state)
- output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ output = self.main({'obs': agent_state, 'prev_state': prev_state})
logit, next_state = output['logit'], output['next_state']
next_state, _ = list_split(next_state, step=A)
logit = logit.reshape(T, B, A, -1)
119 changes: 65 additions & 54 deletions ding/model/template/q_learning.py
@@ -855,18 +855,22 @@ def reshape(d):
class DRQN(nn.Module):
"""
Overview:
- The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \
- common DQN variant for sequential data and paratially observable environment. The DRQN is composed of three \
- parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \
- observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \
- used to compute the Q value of each action dimension.
+ The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential
+ data and partially observable environments. It consists of three main components: ``encoder``, ``rnn``,
+ and ``head``.
+ - **Encoder**: Extracts features from various observation inputs.
+ - **RNN**: Processes sequential observations and other data.
+ - **Head**: Computes Q-values for each action dimension.

Interfaces:
``__init__``, ``forward``.

.. note::
- Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
- ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \
- ``gru``. You can customize your own encoder, rnn or head by inheriting this class.
+ The current implementation supports:
+ - Two encoder types: ``FCEncoder`` and ``ConvEncoder``.
+ - Two head types: ``DiscreteHead`` and ``DuelingHead``.
+ - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``.
+ You can extend the model by customizing your own encoder, RNN, or head by inheriting this class.
"""

def __init__(
@@ -884,43 +888,48 @@ def __init__(
) -> None:
"""
Overview:
- Initialize the DRQN Model according to the corresponding input arguments.
+ Initialize the DRQN model with specified parameters.
Arguments:
- - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
- - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
- the last element must match ``head_hidden_size``.
- - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
- - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
- then it will be set to the last element of ``encoder_hidden_size_list``.
- - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
- - lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru'].
- - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
- if ``None`` then default set it to ``nn.ReLU()``.
- - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
- ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
- - res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connnection between \
- single frame data and the sequential data, defaults to False.
+ - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \
+ must match ``head_hidden_size``.
+ - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``.
+ - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \
+ element of ``encoder_hidden_size_list`` if None.
+ - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs.
+ - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \
+ and ``gru``.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to \
+ ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \
+ ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details.
+ - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \
+ Defaults to False.
"""
super(DRQN, self).__init__()
- # For compatibility: 1, (1, ), [4, 32, 32]
+ # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs.
obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
if head_hidden_size is None:
head_hidden_size = encoder_hidden_size_list[-1]
- # FC Encoder

+ # Encoder: Determines the encoder type based on the observation shape.
if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ # FC Encoder
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
- # Conv Encoder
elif len(obs_shape) == 3:
+ # Conv Encoder
self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
else:
raise RuntimeError(
"not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
f"Unsupported obs_shape for pre-defined encoder: {obs_shape}. Please customize your own DRQN."
)
- # LSTM Type

+ # RNN: Initializes the RNN module based on the specified lstm_type.
self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)
self.res_link = res_link
- # Head Type

+ # Head: Determines the head type (Dueling or Discrete) and its configuration.
if dueling:
head_cls = DuelingHead
else:
@@ -943,31 +952,32 @@ def __init__(
def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
"""
Overview:
- DRQN forward computation graph, input observation tensor to predict q_value.
+ Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \
+ and predicts Q-values.
Arguments:
- - inputs (:obj:`torch.Tensor`): The dict of input data, including observation and previous rnn state.
- - inference: (:obj:'bool'): Whether to enable inference forward mode, if True, we unroll the one timestep \
- transition, otherwise, we unroll the eentire sequence transitions.
- - saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, we unroll the sequence \
- transitions, then we would use this list to indicate how to save and return hidden state.
+ - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state.
+ - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \
+ the entire sequence (used during training).
+ - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \
+ whose hidden states are saved and returned.
ArgumentsKeys:
- - obs (:obj:`torch.Tensor`): The raw observation tensor.
- - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``.
+ - obs (:obj:`torch.Tensor`): Raw observation tensor.
+ - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``.
Returns:
- outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
ReturnsKeys:
- - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
- - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type``.
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension.
+ - next_state (:obj:`list`): Next RNN state tensor.
Shapes:
- - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
- - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``.
+ - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``.
Examples:
- >>> # Init input's Keys:
+ >>> # Initialize input keys
>>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
>>> obs = torch.randn(4,64)
>>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
>>> outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True)
- >>> # Check outputs's Keys
+ >>> # Validate output keys and shapes
>>> assert isinstance(outputs, dict)
>>> assert outputs['logit'].shape == (4, 64)
>>> assert len(outputs['next_state']) == 4
@@ -976,9 +986,9 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
"""

x, prev_state = inputs['obs'], inputs['prev_state']
- # for both inference and other cases, the network structure is encoder -> rnn network -> head
- # the difference is inference take the data with seq_len=1 (or T = 1)
- # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training
+ # Forward pass: Encoder -> RNN -> Head
+ # in most situations, set inference=True when evaluate and inference=False when training
+ # Inference mode: Processes one timestep (seq_len=1).
if inference:
x = self.encoder(x)
if self.res_link:
@@ -992,27 +1002,28 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
x = self.head(x)
x['next_state'] = next_state
return x
+ # Training mode: Processes the entire sequence.
else:
# In order to better explain why rnn needs saved_state and which states need to be stored,
# let's take r2d2 as an example
# in r2d2,
# 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
# 2) data['main_obs'] = data['obs'][bs:-self._nstep]
# 3) data['target_obs'] = data['obs'][bs + self._nstep:]
- # NOTE(rjy): (T, B, N) or (T, B, C, H, W)
- assert len(x.shape) in [3, 5], x.shape
+ assert len(x.shape) in [3, 5], f"Expected shape (T, B, N) or (T, B, C, H, W), got {x.shape}"
x = parallel_wrapper(self.encoder)(x) # (T, B, N)
if self.res_link:
a = x
- # NOTE(rjy) lstm_embedding stores all hidden_state
+ # lstm_embedding stores all hidden_state
lstm_embedding = []
# TODO(nyz) how to deal with hidden_size key-value
hidden_state_list = []

if saved_state_timesteps is not None:
saved_state = []
- for t in range(x.shape[0]): # T timesteps
- # NOTE(rjy) use x[t:t+1] but not x[t] can keep original dimension
- output, prev_state = self.rnn(x[t:t + 1], prev_state) # output: (1,B, head_hidden_size)
+ for t in range(x.shape[0]): # Iterate over timesteps (T).
+ # use x[t:t+1] but not x[t] can keep the original dimension
+ output, prev_state = self.rnn(x[t:t + 1], prev_state) # RNN step output: (1, B, hidden_size)
if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
saved_state.append(prev_state)
lstm_embedding.append(output)
@@ -1023,7 +1034,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
if self.res_link:
x = x + a
x = parallel_wrapper(self.head)(x) # (T, B, action_shape)
- # NOTE(rjy): x['next_state'] is the hidden state of the last timestep inputted to lstm
+ # x['next_state'] is the hidden state of the last timestep inputted to lstm
# the last timestep state including the hidden state (h) and the cell state (c)
# shape: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size}}}
x['next_state'] = prev_state
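
For reference, a minimal sketch of the two forward modes documented in the DRQN docstring above; the construction and sizes are illustrative assumptions, not code taken from the repository's tests.

```python
import torch
from ding.model.template.q_learning import DRQN

T, B, N, A = 8, 4, 64, 6
# Constructor options map onto the components listed in the docstring:
# encoder (FC/Conv, chosen from obs_shape), rnn (lstm_type), head (dueling or not).
model = DRQN(N, A)

# Inference mode: one timestep per call, carrying the recurrent state explicitly.
obs_t = torch.randn(B, N)
prev_state = [None for _ in range(B)]  # None entries are zero-initialized by the LSTM wrapper
out = model({'obs': obs_t, 'prev_state': prev_state}, inference=True)
assert out['logit'].shape == (B, A)
prev_state = out['next_state']  # fed back on the next environment step

# Training mode: the whole (T, B, N) sequence is unrolled in one call.
seq_obs = torch.randn(T, B, N)
out = model({'obs': seq_obs, 'prev_state': None}, inference=False)
assert out['logit'].shape == (T, B, A)
```
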
2 changes: 1 addition & 1 deletion ding/model/template/qmix.py
@@ -227,7 +227,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
prev_state = reduce(lambda x, y: x + y, prev_state)
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
- output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
agent_q, next_state = output['logit'], output['next_state']
next_state, _ = list_split(next_state, step=A)
agent_q = agent_q.reshape(T, B, A, -1)
2 changes: 1 addition & 1 deletion ding/model/template/qtran.py
@@ -100,7 +100,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
prev_state = reduce(lambda x, y: x + y, prev_state)
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
- output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
agent_q, next_state = output['logit'], output['next_state']
next_state, _ = list_split(next_state, step=A)
agent_q = agent_q.reshape(T, B, A, -1)
2 changes: 0 additions & 2 deletions ding/model/template/wqmix.py
@@ -177,7 +177,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
{
'obs': agent_state,
'prev_state': prev_state,
- 'enable_fast_timestep': True
}
) # here is the forward pass of the agent networks of Q_star
agent_q, next_state = output['logit'], output['next_state']
@@ -223,7 +222,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
{
'obs': agent_state,
'prev_state': prev_state,
- 'enable_fast_timestep': True
}
) # here is the forward pass of the agent networks of Q
agent_q, next_state = output['logit'], output['next_state']
3 changes: 0 additions & 3 deletions ding/policy/ngu.py
@@ -290,7 +290,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
'action': data['burnin_nstep_action'],
'reward': data['burnin_nstep_reward'],
'beta': data['burnin_nstep_beta'],
- 'enable_fast_timestep': True
}
tmp = self._learn_model.forward(
inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
@@ -304,7 +303,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
'action': data['main_action'],
'reward': data['main_reward'],
'beta': data['main_beta'],
- 'enable_fast_timestep': True
}
self._learn_model.reset(data_id=None, state=tmp['saved_state'][0])
q_value = self._learn_model.forward(inputs)['logit']
@@ -317,7 +315,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
'action': data['target_action'],
'reward': data['target_reward'],
'beta': data['target_beta'],
- 'enable_fast_timestep': True
}
with torch.no_grad():
target_q_value = self._target_model.forward(next_inputs)['logit']
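
The three hunks above follow the burn-in pattern referenced in the DRQN comments (burn-in segment, main segment, target segment). Below is a self-contained sketch of that pattern driven directly against `DRQN`; the segment lengths and variable names are assumptions for illustration, and the actual policy uses a separate target network plus model wrappers rather than raw forwards.

```python
import torch
from ding.model.template.q_learning import DRQN

burnin_step, nstep, unroll_len, B, N, A = 2, 3, 10, 4, 64, 6
model = DRQN(N, A)
obs = torch.randn(unroll_len, B, N)

# 1) Unroll only the burn-in slice to warm up the recurrent state, saving the
#    states reached after `burnin_step` and `burnin_step + nstep` steps.
burnin_out = model(
    {'obs': obs[:burnin_step + nstep], 'prev_state': None},
    inference=False,
    saved_state_timesteps=[burnin_step, burnin_step + nstep],
)

# 2) The main segment continues from the state saved after the burn-in steps.
main_out = model(
    {'obs': obs[burnin_step:-nstep], 'prev_state': burnin_out['saved_state'][0]},
    inference=False,
)
q_value = main_out['logit']  # (unroll_len - burnin_step - nstep, B, A)

# 3) The target segment starts from the later saved state (computed without gradients).
with torch.no_grad():
    target_out = model(
        {'obs': obs[burnin_step + nstep:], 'prev_state': burnin_out['saved_state'][1]},
        inference=False,
    )
    target_q_value = target_out['logit']
```
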