Skip to content

Commit

Permalink
define activation function at the beginning
Browse files — browse the repository at this point in the history
  • Loading branch information
borauyar committed Apr 4, 2024
1 parent 75f5b6e commit f2db9d6
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions flexynesis/modules.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,19 +19,19 @@ class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dims, latent_dim):
super(Encoder, self).__init__()

self.LeakyReLU = nn.LeakyReLU(0.2)
self.act = nn.LeakyReLU(0.2)

hidden_layers = []

hidden_layers.append(nn.Linear(input_dim, hidden_dims[0]))
nn.init.xavier_uniform_(hidden_layers[-1].weight)
hidden_layers.append(self.LeakyReLU)
hidden_layers.append(self.act)
hidden_layers.append(nn.BatchNorm1d(hidden_dims[0]))

for i in range(len(hidden_dims)-1):
hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
nn.init.xavier_uniform_(hidden_layers[-1].weight)
hidden_layers.append(self.LeakyReLU)
hidden_layers.append(self.act)
hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1]))

self.hidden_layers = nn.Sequential(*hidden_layers)
Expand Down Expand Up @@ -68,19 +68,19 @@ class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dims, output_dim):
super(Decoder, self).__init__()

self.LeakyReLU = nn.LeakyReLU(0.2)
self.act = nn.LeakyReLU(0.2)

hidden_layers = []

hidden_layers.append(nn.Linear(latent_dim, hidden_dims[0]))
nn.init.xavier_uniform_(hidden_layers[-1].weight)
hidden_layers.append(self.LeakyReLU)
hidden_layers.append(self.act)
hidden_layers.append(nn.BatchNorm1d(hidden_dims[0]))

for i in range(len(hidden_dims) - 1):
hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
nn.init.xavier_uniform_(hidden_layers[-1].weight)
hidden_layers.append(self.LeakyReLU)
hidden_layers.append(self.act)
hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1]))

self.hidden_layers = nn.Sequential(*hidden_layers)
Expand All @@ -99,7 +99,7 @@ def forward(self, x):
x_hat (torch.Tensor): The reconstructed output tensor.
"""
h = self.hidden_layers(x)
x_hat = torch.tanh(self.FC_output(h))
x_hat = torch.sigmoid(self.FC_output(h))
return x_hat


Expand Down

0 comments on commit f2db9d6

Please sign in to comment.