-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbaseRNN.py
42 lines (35 loc) · 1.43 KB
/
baseRNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
""" A base class for RNN. """
import torch.nn as nn
class BaseRNN(nn.Module):
    r"""
    Applies a multi-layer RNN to an input sequence.

    Note:
        Do not use this class directly, use one of the sub classes.

    Args:
        vocab_size (int): size of the vocabulary
        hidden_dim (int): number of features in the hidden state `h`
        dropout_p (float): dropout probability for the output sequence
        n_layers (int): number of recurrent layers
        rnn_cell (str): type of RNN cell, case-insensitive (Eg. 'LSTM' , 'GRU')

    Raises:
        ValueError: if ``rnn_cell`` is not 'LSTM' or 'GRU' (case-insensitive).

    Inputs: ``*args``, ``**kwargs``
        - ``*args``: variable length argument list.
        - ``**kwargs``: arbitrary keyword arguments.
    """

    def __init__(self, vocab_size, hidden_dim, dropout_p, n_layers, rnn_cell):
        super(BaseRNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # Store the RNN *class* (not an instance); subclasses are expected to
        # instantiate it themselves, e.g. self.rnn_cell(input_size, ...).
        cell_name = rnn_cell.lower()
        if cell_name == 'lstm':
            self.rnn_cell = nn.LSTM
        elif cell_name == 'gru':
            self.rnn_cell = nn.GRU
        else:
            raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell))
        self.dropout_p = dropout_p

    def forward(self, *args, **kwargs):
        """Abstract forward pass; must be implemented by subclasses."""
        raise NotImplementedError()