Skip to content

Commit

Permalink
Turn on MLPs for batch variables
Browse files Browse the repository at this point in the history
  • Loading branch information
borauyar committed Nov 23, 2023
1 parent 4a1a930 commit 81ccf31
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion flexynesis/models/direct_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, va
output_dim=self.config['latent_dim']) for i in range(len(layers))])

self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs
for var in self.target_variables:
for var in self.variables:
if self.dataset.variable_types[var] == 'numerical':
num_class = 1
else:
Expand Down
2 changes: 1 addition & 1 deletion flexynesis/models/supervised_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, v
# define supervisor heads
# using ModuleDict to store multiple MLPs
self.MLPs = nn.ModuleDict()
for var in self.target_variables:
for var in self.variables:
if self.dataset.variable_types[var] == 'numerical':
num_class = 1
else:
Expand Down
2 changes: 1 addition & 1 deletion flexynesis/models/triplet_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, va

# define supervisor heads for both target and batch variables
self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs
for var in self.target_variables:
for var in self.variables:
if self.variable_types[var] == 'numerical':
num_class = 1
else:
Expand Down

0 comments on commit 81ccf31

Please sign in to comment.