
Commit

Isolate everything, no base (#12)
trsvchn authored Nov 14, 2023
1 parent cd8575a commit 4a1a930
Showing 3 changed files with 234 additions and 255 deletions.
231 changes: 0 additions & 231 deletions flexynesis/models/base_direct_pred.py

This file was deleted.

29 changes: 11 additions & 18 deletions flexynesis/models/direct_pred.py
@@ -26,32 +26,25 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, va
         self.variables = target_variables + batch_variables if batch_variables else target_variables
         self.val_size = val_size
         self.dat_train, self.dat_val = self.prepare_data()
-        # Instantiate layers.
-        self._init_encoders()
-        self._init_output_layers()
-        self.feature_importances = {}
+        self.feature_importances = {}
 
-    def _init_encoders(self):
-        layers = list(self.dataset.dat.keys())
-        input_dims = [len(self.dataset.features[layers[i]]) for i in range(len(layers))]
+        layers = list(dataset.dat.keys())
+        input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))]
+
         self.encoders = nn.ModuleList([
-            MLP(input_dim=input_dims[i], hidden_dim=self.config["hidden_dim"], output_dim=self.config["latent_dim"])
-            for i in range(len(layers))
-        ])
+            MLP(input_dim=input_dims[i],
+                hidden_dim=self.config['hidden_dim'],
+                output_dim=self.config['latent_dim']) for i in range(len(layers))])
 
-    def _init_output_layers(self):
-        layers = list(self.dataset.dat.keys())
         self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs
         for var in self.target_variables:
-            if self.dataset.variable_types[var] == "numerical":
+            if self.dataset.variable_types[var] == 'numerical':
                 num_class = 1
             else:
                 num_class = len(np.unique(self.dataset.ann[var]))
-            self.MLPs[var] = MLP(
-                input_dim=self.config["latent_dim"] * len(layers),
-                hidden_dim=self.config["hidden_dim"],
-                output_dim=num_class,
-            )
+            self.MLPs[var] = MLP(input_dim=self.config['latent_dim'] * len(layers),
+                                 hidden_dim=self.config['hidden_dim'],
+                                 output_dim=num_class)
 
     def forward(self, x_list):
         """
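As an aside for readers of the hunk above: the sketch below shows, in isolation, the construction pattern the resulting __init__ sets up, one MLP encoder per omics layer collected in an nn.ModuleList, and one MLP output head per target variable in an nn.ModuleDict, with the concatenated latent vectors feeding each head. The MLP stand-in, the dimensions, the class/variable names, and the forward fusion are illustrative assumptions, not the exact flexynesis implementation.

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Stand-in for the flexynesis MLP: a simple two-layer perceptron (assumed shape)."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        return self.net(x)


class DirectPredSketch(nn.Module):
    """Hypothetical, self-contained illustration of the pattern in the diff above."""
    def __init__(self, input_dims, target_classes, hidden_dim=64, latent_dim=16):
        super().__init__()
        # One encoder per omics layer, mirroring the nn.ModuleList in the diff.
        self.encoders = nn.ModuleList([MLP(d, hidden_dim, latent_dim) for d in input_dims])
        # One output head per target variable, mirroring the nn.ModuleDict;
        # the output size is 1 for numerical targets, otherwise the number of categories.
        self.heads = nn.ModuleDict({var: MLP(latent_dim * len(input_dims), hidden_dim, n)
                                    for var, n in target_classes.items()})

    def forward(self, x_list):
        # Encode each omics layer, concatenate the latent vectors, apply every head.
        z = torch.cat([enc(x) for enc, x in zip(self.encoders, x_list)], dim=-1)
        return {var: head(z) for var, head in self.heads.items()}


# Toy usage: two fake omics layers (100 and 50 features), one numerical and one 3-class target.
x_list = [torch.randn(8, 100), torch.randn(8, 50)]
model = DirectPredSketch(input_dims=[100, 50], target_classes={"age": 1, "subtype": 3})
outputs = model(x_list)  # {'age': shape (8, 1), 'subtype': shape (8, 3)}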
