-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathminLSTMcell.py
38 lines (29 loc) · 1.4 KB
/
minLSTMcell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
class MinLSTMCell(nn.Module):
    """Single-step minimal LSTM (minLSTM) cell.

    The gates depend only on the current input ``x_t`` (not on the previous
    hidden state), and the forget/input gates are normalized to sum to one:

        f_t  = sigmoid(Linear_f(x_t))
        i_t  = sigmoid(Linear_i(x_t))
        h~_t = Linear_h(x_t)
        f'_t = f_t / (f_t + i_t),   i'_t = i_t / (f_t + i_t)
        h_t  = f'_t * h_{t-1} + i'_t * h~_t

    Args:
        units: Size of the hidden state (``out_features`` of each linear layer).
        input_shape: Number of input features of ``x_t``; despite the name it
            is the integer ``in_features`` passed to ``nn.Linear``.
    """

    def __init__(self, units: int, input_shape: int) -> None:
        super().__init__()
        self.units = units
        self.input_shape = input_shape
        # Input-only projections for the forget gate, input gate and candidate
        # hidden state; each nn.Linear includes a bias term (W x + b).
        self.linear_f = nn.Linear(self.input_shape, self.units)
        self.linear_i = nn.Linear(self.input_shape, self.units)
        self.linear_h = nn.Linear(self.input_shape, self.units)

    def forward(self, pre_h: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """Advance the cell by one time step.

        Args:
            pre_h: Previous hidden state h_{t-1}, shape ``(batch_size, units)``.
            x_t: Input at time step t, shape ``(batch_size, input_shape)``.

        Returns:
            New hidden state h_t, shape ``(batch_size, units)``.
        """
        # Gate pre-activations and candidate state, each (batch_size, units).
        f_logits = self.linear_f(x_t)
        i_logits = self.linear_i(x_t)
        tilde_h_t = self.linear_h(x_t)

        # Normalized gates f'_t = f_t / (f_t + i_t), i'_t = i_t / (f_t + i_t).
        # Dividing the raw sigmoids yields 0/0 = NaN once both gates underflow
        # to zero (pre-activations below roughly -104 in float32, roughly -17
        # in float16). The same ratio is computed stably in log space:
        #   f'_t = sigmoid(log f_t - log i_t),  i'_t = sigmoid(log i_t - log f_t)
        # where log f_t = logsigmoid(f_logits) stays finite for any input.
        log_ratio = nn.functional.logsigmoid(f_logits) - nn.functional.logsigmoid(i_logits)
        f_prime_t = torch.sigmoid(log_ratio)
        i_prime_t = torch.sigmoid(-log_ratio)

        # Convex combination of the previous state and the candidate state.
        h_t = f_prime_t * pre_h + i_prime_t * tilde_h_t
        return h_t