From cf0a26130961ca06a33f9f300eefb61a3bef0458 Mon Sep 17 00:00:00 2001
From: mustapha ajeghrir <66799406+Mustapha-AJEGHRIR@users.noreply.github.com>
Date: Thu, 4 Aug 2022 21:25:40 +0200
Subject: [PATCH] Adding batch support for LSTM_AE

---
 sequitur/models/lstm_ae.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/sequitur/models/lstm_ae.py b/sequitur/models/lstm_ae.py
index 8b1eec7..cf89417 100644
--- a/sequitur/models/lstm_ae.py
+++ b/sequitur/models/lstm_ae.py
@@ -27,7 +27,6 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ, out_activ):
         self.h_activ, self.out_activ = h_activ, out_activ
 
     def forward(self, x):
-        x = x.unsqueeze(0)
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)
 
@@ -36,7 +35,7 @@ def forward(self, x):
             elif self.out_activ and index == self.num_layers - 1:
                 return self.out_activ(h_n).squeeze()
 
-        return h_n.squeeze()
+        return h_n
 
 
 class Decoder(nn.Module):
@@ -56,20 +55,21 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ):
             self.layers.append(layer)
 
         self.h_activ = h_activ
-        self.dense_matrix = nn.Parameter(
-            torch.rand((layer_dims[-1], out_dim), dtype=torch.float),
-            requires_grad=True
-        )
+        self.dense_layer = nn.Linear(layer_dims[-1], out_dim)
 
     def forward(self, x, seq_len):
-        x = x.repeat(seq_len, 1).unsqueeze(0)
+        if len(x.shape) == 1:  # In case the batch dimension is not there
+            x = x.repeat(seq_len, 1)  # Add the sequence dimension by repeating the embedding
+        else:
+            x = x.unsqueeze(1).repeat(1, seq_len, 1)  # Add the sequence dimension by repeating the embedding
+
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)
 
             if self.h_activ and index < self.num_layers - 1:
                 x = self.h_activ(x)
 
-        return torch.mm(x.squeeze(), self.dense_matrix)
+        return self.dense_layer(x)
 
 
 ######
@@ -88,7 +88,10 @@ def __init__(self, input_dim, encoding_dim, h_dims=[], h_activ=nn.Sigmoid(),
                                h_activ)
 
     def forward(self, x):
-        seq_len = x.shape[0]
+        if len(x.shape) <= 2:  # In case the batch dimension is not there
+            seq_len = x.shape[0]
+        else:
+            seq_len = x.shape[1]
         x = self.encoder(x)
         x = self.decoder(x, seq_len)
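
For reference, a minimal sketch of how the patched model can be called, assuming
sequitur is installed with this change applied and torch >= 1.11 (needed for the
unbatched 2-D LSTM inputs the encoder now receives). The dimensions below are
arbitrary illustration values, not part of the patch:

    import torch
    from sequitur.models import LSTM_AE

    # Arbitrary dimensions for illustration
    model = LSTM_AE(input_dim=3, encoding_dim=7, h_dims=[32])

    # Single sequence, the pre-patch calling convention: (seq_len, input_dim)
    x = torch.randn(10, 3)
    x_hat = model(x)       # -> (10, 3)

    # Batched input, enabled by this patch: (batch, seq_len, input_dim)
    xb = torch.randn(16, 10, 3)
    xb_hat = model(xb)     # -> (16, 10, 3)

Both paths exercise the new branches above: the decoder repeats the embedding
along a fresh sequence dimension, and because nn.Linear applies to the trailing
feature dimension of any input shape, the old squeeze/torch.mm bookkeeping is
no longer needed.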