Adding Batch support for LSTM_AE #10
base: master
@@ -27,7 +27,6 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ, out_activ):
         self.h_activ, self.out_activ = h_activ, out_activ

     def forward(self, x):
-        x = x.unsqueeze(0)
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)
@@ -36,7 +35,7 @@ def forward(self, x):
         elif self.out_activ and index == self.num_layers - 1:
             return self.out_activ(h_n).squeeze()

-        return h_n.squeeze()
+        return h_n


 class Decoder(nn.Module):
@@ -56,20 +55,21 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ):
             self.layers.append(layer)

         self.h_activ = h_activ
-        self.dense_matrix = nn.Parameter(
-            torch.rand((layer_dims[-1], out_dim), dtype=torch.float),
-            requires_grad=True
-        )
+        self.dense_layer = nn.Linear(layer_dims[-1], out_dim)

Review comment on the 'nn.Linear' line: The dense layer is created using the 'nn.Linear' method. It would be beneficial to add a comment explaining why this change was made from using a dense matrix to a dense layer. [medium]
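To illustrate that point, here is a minimal sketch (not part of the PR; the sizes are made up) of how the new nn.Linear layer relates to the old dense matrix. With the bias disabled and the weight set to the transpose of the old matrix the two projections agree, and nn.Linear also applies over the last dimension of a batched 3-D tensor, which torch.mm cannot:

import torch
import torch.nn as nn

in_dim, out_dim = 16, 8  # illustrative sizes only

# Old approach: a raw parameter multiplied in with torch.mm (2-D inputs only)
dense_matrix = torch.rand((in_dim, out_dim), dtype=torch.float)

# New approach: nn.Linear stores its weight as (out_dim, in_dim)
linear = nn.Linear(in_dim, out_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(dense_matrix.t())

x2d = torch.randn(10, in_dim)     # (seq_len, features): the old, unbatched case
x3d = torch.randn(4, 10, in_dim)  # (batch, seq_len, features): the new, batched case

print(torch.allclose(torch.mm(x2d, dense_matrix), linear(x2d), atol=1e-6))  # True
print(linear(x3d).shape)  # torch.Size([4, 10, 8]); torch.mm would raise an error here

Note that with its defaults nn.Linear adds a bias term and uses a different weight initialisation than torch.rand, so the new layer is not numerically identical to the old parameter; that difference is worth stating in the comment the reviewer asks for.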
     def forward(self, x, seq_len):
-        x = x.repeat(seq_len, 1).unsqueeze(0)
+        if len(x.shape) == 1 :  # In case the batch dimension is not there
+            x = x.repeat(seq_len, 1)  # Add the sequence dimension by repeating the embedding
+        else :
+            x = x.unsqueeze(1).repeat(1, seq_len, 1)  # Add the sequence dimension by repeating the embedding

Review comment on the 'if' check: It's good practice to handle tensor shape manipulations in a separate function. This would make the code more readable and maintainable. [medium]

Review comment on the 'x.repeat(seq_len, 1)' line: Consider using the .unsqueeze() function to add an extra dimension to the tensor instead of using .repeat(). This could potentially improve performance, as it avoids creating a larger tensor. [medium]

Review comment on the same line: The 'repeat' method is used to repeat the tensor along a specified dimension, but it is not immediately clear why this is necessary. Adding a comment to explain this would improve code readability. [medium]

Review comment on the 'else' branch: The 'unsqueeze' method is used to add an extra dimension to the tensor, but only in one of the two branches. It would be good to add a comment explaining why this is necessary, to improve code readability. [medium]
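A sketch of the refactor those comments point toward (the helper name is hypothetical, not something in the repo): pull the shape handling into one small function and document why the embedding is repeated along a new sequence dimension:

import torch

def add_sequence_dim(x: torch.Tensor, seq_len: int) -> torch.Tensor:
    # The encoder compresses a whole sequence into a single embedding; the
    # decoder LSTM needs an input at every time step, so that embedding is
    # repeated seq_len times along a new sequence dimension.
    if x.dim() == 1:                             # (features,): no batch dimension
        return x.repeat(seq_len, 1)              # -> (seq_len, features)
    return x.unsqueeze(1).repeat(1, seq_len, 1)  # -> (batch, seq_len, features)

On the .unsqueeze() suggestion: unsqueeze alone only adds a size-1 dimension, so it cannot replace repeat here; the memory-saving variant would be x.unsqueeze(1).expand(-1, seq_len, -1), which returns a view instead of a copy, at the cost of the result not being contiguous.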
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)

             if self.h_activ and index < self.num_layers - 1:
                 x = self.h_activ(x)

-        return torch.mm(x.squeeze(), self.dense_matrix)
+        return self.dense_layer(x)


 ######
@@ -88,7 +88,10 @@ def __init__(self, input_dim, encoding_dim, h_dims=[], h_activ=nn.Sigmoid(),
                                h_activ)

     def forward(self, x):
-        seq_len = x.shape[0]
+        if len(x.shape) <= 2 :  # In case the batch dimension is not there
+            seq_len = x.shape[0]
+        else :
+            seq_len = x.shape[1]
         x = self.encoder(x)
         x = self.decoder(x, seq_len)

Review comment on the shape check: Instead of checking the tensor dimensions in every forward pass, consider reshaping the input tensor to always have a batch dimension. This would simplify the forward methods and potentially improve performance. [important]

Review comment on the same lines: Consider adding comments to explain the purpose of the conditional statements checking the shape of 'x'. This would make the code more readable and easier to maintain. [medium]
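A rough sketch of that 'always add a batch dimension' suggestion (an assumption, not what this PR does: it presumes the encoder and decoder are written to accept (batch, seq_len, features) input throughout):

def forward(self, x):
    if x.dim() == 2:      # (seq_len, features): add a batch dimension of size 1
        x = x.unsqueeze(0)
    seq_len = x.shape[1]  # the sequence length is always dimension 1 after this
    x = self.encoder(x)
    x = self.decoder(x, seq_len)
    return x

With the input normalised once at the top, the dimension checks inside Encoder.forward and Decoder.forward could be dropped.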
Review comment on the PR as a whole: The change from using a dense matrix to a linear layer is a good one, as it makes the code more readable and leverages PyTorch's built-in functionality. However, ensure that this change doesn't affect the model's performance or results. [medium]
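One quick way to address that concern is a smoke test over both the unbatched and the batched paths (a sketch only; it assumes the LSTM_AE class from this branch is importable and uses made-up sizes):

import torch

model = LSTM_AE(input_dim=32, encoding_dim=8, h_dims=[16])

single = torch.randn(50, 32)    # (seq_len, features)
batch = torch.randn(4, 50, 32)  # (batch, seq_len, features)

print(model(single).shape)  # expected: torch.Size([50, 32])
print(model(batch).shape)   # expected: torch.Size([4, 50, 32])

Training briefly on the same data before and after the change and comparing reconstruction loss would address the performance part of the comment, since nn.Linear's bias term and initialisation differ from the old torch.rand parameter.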