From e3a068be61ea13c1b5343a56319f3219e16268d1 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Sun, 15 Oct 2023 20:50:17 +0200 Subject: [PATCH] Final fixes --- flexynesis/__init__.py | 3 ++- flexynesis/__main__.py | 5 ++++- flexynesis/models_shared.py | 5 +++-- flexynesis/modules.py | 2 ++ 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/flexynesis/__init__.py b/flexynesis/__init__.py index 13846fa..a08068d 100644 --- a/flexynesis/__init__.py +++ b/flexynesis/__init__.py @@ -59,4 +59,5 @@ from .feature_selection import * from .data_augmentation import * from .utils import * -from .config import * \ No newline at end of file +from .config import * +from . import models diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 869b389..a363b45 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -11,7 +11,7 @@ def main(): formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True) - parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "supervised_vae", "MultiTripletNetwork"], required = True) + parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredCNN", "supervised_vae", "MultiTripletNetwork"], required = True) parser.add_argument("--target_variables", help="(Required) Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", type = str, required = True) parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.') parser.add_argument("--batch_variables", @@ -51,6 +51,9 @@ def main(): elif args.model_class == "MultiTripletNetwork": model_class = flexynesis.MultiTripletNetwork config_name = 'MultiTripletNetwork' + elif args.model_class == "DirectPredCNN": + model_class = flexynesis.models.DirectPredCNN + config_name = 'DirectPredCNN' else: raise ValueError(f"Invalid model_class: {args.model_class}") diff --git a/flexynesis/models_shared.py b/flexynesis/models_shared.py index c5241f2..900887e 100644 --- a/flexynesis/models_shared.py +++ b/flexynesis/models_shared.py @@ -132,14 +132,15 @@ def forward(self, x): x (torch.Tensor): The output tensor after passing through the MLP network. """ x = self.layer_1(x) - x = self.dropout(x) - if x.size(0) != 1: # Skip BatchNorm if batch size is 1 + # x = self.dropout(x) + if (x.size(0) != 1) and self.training: # Skip BatchNorm if batch size is 1 x = self.batchnorm(x) x = self.relu(x) x = self.dropout(x) x = self.layer_out(x) return x + class EmbeddingNetwork(nn.Module): """ A simple feed-forward neural network for generating embeddings. diff --git a/flexynesis/modules.py b/flexynesis/modules.py index efe763d..d5c0169 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -10,6 +10,7 @@ class CNN(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() + self.layer_1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=1) self.batchnorm = nn.BatchNorm1d(hidden_dim) self.relu = nn.ReLU() @@ -22,6 +23,7 @@ def forward(self, x): x = x.unsqueeze(-1) x = self.layer_1(x) + # TODO: for 1 at train x = self.batchnorm(x) x = self.relu(x) x = self.dropout(x)