Skip to content

Commit

Permalink
Final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
trsvchn committed Oct 15, 2023
1 parent 2be179f commit e3a068b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion flexynesis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@
from .feature_selection import *
from .data_augmentation import *
from .utils import *
from .config import *
from .config import *
from . import models
5 changes: 4 additions & 1 deletion flexynesis/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--target_variables", help="(Required) Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", type = str, required = True)
parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
parser.add_argument("--batch_variables",
Expand Down Expand Up @@ -51,6 +51,9 @@ def main():
elif args.model_class == "MultiTripletNetwork":
model_class = flexynesis.MultiTripletNetwork
config_name = 'MultiTripletNetwork'
elif args.model_class == "DirectPredCNN":
model_class = flexynesis.models.DirectPredCNN
config_name = 'DirectPredCNN'
else:
raise ValueError(f"Invalid model_class: {args.model_class}")

Expand Down
5 changes: 3 additions & 2 deletions flexynesis/models_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ def forward(self, x):
x (torch.Tensor): The output tensor after passing through the MLP network.
"""
x = self.layer_1(x)
x = self.dropout(x)
if x.size(0) != 1: # Skip BatchNorm if batch size is 1
# x = self.dropout(x)
if (x.size(0) != 1) and self.training: # Skip BatchNorm if batch size is 1
x = self.batchnorm(x)
x = self.relu(x)
x = self.dropout(x)
x = self.layer_out(x)
return x


class EmbeddingNetwork(nn.Module):
"""
A simple feed-forward neural network for generating embeddings.
Expand Down
2 changes: 2 additions & 0 deletions flexynesis/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class CNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()

self.layer_1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
self.batchnorm = nn.BatchNorm1d(hidden_dim)
self.relu = nn.ReLU()
Expand All @@ -22,6 +23,7 @@ def forward(self, x):
x = x.unsqueeze(-1)

x = self.layer_1(x)
# TODO: for 1 at train
x = self.batchnorm(x)
x = self.relu(x)
x = self.dropout(x)
Expand Down

0 comments on commit e3a068b

Please sign in to comment.