Add vanilla CNN (#9)

* Refactor DirectPred class: add 2 methods for layer creation

* Add Simple CNN implementation

* Duplicate example notebook

* Final fixes

* Update example notebook and config file
trsvchn authored Oct 15, 2023
1 parent ed7394b commit b80c972
Showing 10 changed files with 2,887 additions and 19 deletions.
2,770 changes: 2,770 additions & 0 deletions examples/tutorials/brca_subtypes_cnn.ipynb

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions examples/tutorials/conf.yaml
@@ -20,6 +20,28 @@ DirectPred:
     name: epochs
     categories: [200]
 
+DirectPredCNN:
+  - type: Integer
+    name: latent_dim
+    low: 16
+    high: 64
+  - type: Integer
+    name: hidden_dim
+    low: 64
+    high: 512
+  - type: Real
+    name: lr
+    low: 0.0001
+    high: 0.01
+    prior: log-uniform
+  - type: Integer
+    name: batch_size
+    low: 32
+    high: 128
+  - type: Categorical
+    name: epochs
+    categories: [200]
+
 supervised_vae:
   - type: Integer
     name: latent_dim
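
The new block gives DirectPredCNN its own tunable search space in the external YAML config. A quick way to confirm it parses as a list of dimension specs (a sketch, not part of the commit; assumes PyYAML and a checkout of the repo):

    import yaml

    with open("examples/tutorials/conf.yaml") as f:
        conf = yaml.safe_load(f)

    # Each entry under DirectPredCNN describes one hyperparameter search dimension.
    print([d["name"] for d in conf["DirectPredCNN"]])
    # ['latent_dim', 'hidden_dim', 'lr', 'batch_size', 'epochs']
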
3 changes: 2 additions & 1 deletion flexynesis/__init__.py
@@ -59,4 +59,5 @@
 from .feature_selection import *
 from .data_augmentation import *
 from .utils import *
-from .config import *
+from .config import *
+from . import models
5 changes: 4 additions & 1 deletion flexynesis/__main__.py
@@ -11,7 +11,7 @@ def main():
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
-    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "supervised_vae", "MultiTripletNetwork"], required = True)
+    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
     parser.add_argument("--target_variables", help="(Required) Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", type = str, required = True)
     parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
     parser.add_argument("--batch_variables",
@@ -51,6 +51,9 @@ def main():
     elif args.model_class == "MultiTripletNetwork":
         model_class = flexynesis.MultiTripletNetwork
         config_name = 'MultiTripletNetwork'
+    elif args.model_class == "DirectPredCNN":
+        model_class = flexynesis.models.DirectPredCNN
+        config_name = 'DirectPredCNN'
     else:
         raise ValueError(f"Invalid model_class: {args.model_class}")
 
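
With the new choice registered, DirectPredCNN is selected like any other model class. A hypothetical invocation (the data path and target variable are placeholders; assumes the package is run as a module):

    python -m flexynesis --data_path path/to/dataset --model_class DirectPredCNN --target_variables subtype --config_path examples/tutorials/conf.yaml
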
9 changes: 8 additions & 1 deletion flexynesis/config.py
@@ -26,5 +26,12 @@
         Real(0.0001, 0.01, prior='log-uniform', name='lr'),
         Integer(32, 128, name='batch_size'),
         Categorical(epochs, name='epochs')
-    ]
+    ],
+    "DirectPredCNN": [
+        Integer(16, 128, name="latent_dim"),
+        Integer(64, 512, name="hidden_dim"),
+        Real(0.0001, 0.01, prior="log-uniform", name="lr"),
+        Integer(32, 128, name="batch_size"),
+        Categorical(epochs, name="epochs")
+    ],
 }
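
These entries are scikit-optimize search dimensions, so the DirectPredCNN space can be sampled or handed to an optimizer directly. A minimal sketch (assumes scikit-optimize is installed; names mirror the list above, with epochs taken as [200] per the YAML config):

    from skopt.space import Categorical, Integer, Real

    space = [
        Integer(16, 128, name="latent_dim"),
        Integer(64, 512, name="hidden_dim"),
        Real(0.0001, 0.01, prior="log-uniform", name="lr"),
        Integer(32, 128, name="batch_size"),
        Categorical([200], name="epochs"),
    ]
    # Draw one random hyperparameter configuration from the space.
    sample = {dim.name: dim.rvs(n_samples=1)[0] for dim in space}
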
31 changes: 19 additions & 12 deletions flexynesis/model_DirectPred.py
@@ -27,24 +27,31 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, va
         self.val_size = val_size
         self.dat_train, self.dat_val = self.prepare_data()
         self.feature_importances = {}
 
-        layers = list(dataset.dat.keys())
-        input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))]
-
+        # Instantiate layers.
+        self._init_encoders()
+        self._init_output_layers()
+
+    def _init_encoders(self):
+        layers = list(self.dataset.dat.keys())
+        input_dims = [len(self.dataset.features[layers[i]]) for i in range(len(layers))]
         self.encoders = nn.ModuleList([
-            MLP(input_dim=input_dims[i],
-                hidden_dim=self.config['hidden_dim'],
-                output_dim=self.config['latent_dim']) for i in range(len(layers))])
+            MLP(input_dim=input_dims[i], hidden_dim=self.config["hidden_dim"], output_dim=self.config["latent_dim"])
+            for i in range(len(layers))
+        ])
 
-        self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs
+    def _init_output_layers(self):
+        layers = list(self.dataset.dat.keys())
+        self.MLPs = nn.ModuleDict()  # using ModuleDict to store multiple MLPs
         for var in self.target_variables:
-            if self.dataset.variable_types[var] == 'numerical':
+            if self.dataset.variable_types[var] == "numerical":
                 num_class = 1
             else:
                 num_class = len(np.unique(self.dataset.ann[var]))
-            self.MLPs[var] = MLP(input_dim=self.config['latent_dim'] * len(layers),
-                                 hidden_dim=self.config['hidden_dim'],
-                                 output_dim=num_class)
+            self.MLPs[var] = MLP(
+                input_dim=self.config["latent_dim"] * len(layers),
+                hidden_dim=self.config["hidden_dim"],
+                output_dim=num_class,
+            )
 
     def forward(self, x_list):
         """
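
The two factory hooks introduced here, _init_encoders and _init_output_layers, are the extension point for the new model: DirectPredCNN below overrides only these two methods and inherits the rest of DirectPred unchanged.
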
3 changes: 2 additions & 1 deletion flexynesis/models/__init__.py
@@ -1,5 +1,6 @@
 from .direct_pred import DirectPred
+from .direct_pred_cnn import DirectPredCNN
 from .supervised_vae import SupervisedVAE
 from .triplet_encoder import MultiTripletNetwork
 
-__all__ = ["DirectPred", "SupervisedVAE", "MultiTripletNetwork"]
+__all__ = ["DirectPred", "DirectPredCNN", "SupervisedVAE", "MultiTripletNetwork"]
29 changes: 29 additions & 0 deletions flexynesis/models/direct_pred_cnn.py
@@ -0,0 +1,29 @@
+import numpy as np
+from torch import nn
+
+from .direct_pred import DirectPred
+from ..modules import CNN
+
+
+class DirectPredCNN(DirectPred):
+    def _init_encoders(self):
+        layers = list(self.dataset.dat.keys())
+        input_dims = [len(self.dataset.features[layers[i]]) for i in range(len(layers))]
+        self.encoders = nn.ModuleList([
+            CNN(input_dim=input_dims[i], hidden_dim=self.config["hidden_dim"], output_dim=self.config["latent_dim"])
+            for i in range(len(layers))
+        ])
+
+    def _init_output_layers(self):
+        layers = list(self.dataset.dat.keys())
+        self.MLPs = nn.ModuleDict()
+        for var in self.target_variables:
+            if self.dataset.variable_types[var] == "numerical":
+                num_class = 1
+            else:
+                num_class = len(np.unique(self.dataset.ann[var]))
+            self.MLPs[var] = CNN(
+                input_dim=self.config["latent_dim"] * len(layers),
+                hidden_dim=self.config["hidden_dim"],
+                output_dim=num_class,
+            )
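
Note that both the per-omics encoders and the per-target output heads become CNN blocks here; the heads keep the inherited self.MLPs name for compatibility with DirectPred's forward pass.
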
5 changes: 3 additions & 2 deletions flexynesis/models_shared.py
@@ -132,14 +132,15 @@ def forward(self, x):
             x (torch.Tensor): The output tensor after passing through the MLP network.
         """
         x = self.layer_1(x)
-        x = self.dropout(x)
-        if x.size(0) != 1: # Skip BatchNorm if batch size is 1
+        # x = self.dropout(x)
+        if (x.size(0) != 1) and self.training: # Skip BatchNorm if batch size is 1
             x = self.batchnorm(x)
         x = self.relu(x)
+        x = self.dropout(x)
         x = self.layer_out(x)
         return x
 
 
 class EmbeddingNetwork(nn.Module):
     """
     A simple feed-forward neural network for generating embeddings.
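
Two behavior changes ride along with this diff: dropout now runs after the ReLU instead of before BatchNorm, and BatchNorm is applied only while self.training is set, so at evaluation time activations bypass normalization entirely.
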
29 changes: 28 additions & 1 deletion flexynesis/modules.py
@@ -4,4 +4,31 @@
 from .models_shared import Encoder, Decoder, MLP, EmbeddingNetwork, Classifier
 from .model_TripletEncoder import MultiEmbeddingNetwork
 
-__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "MultiEmbeddingNetwork", "Classifier"]
+__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "MultiEmbeddingNetwork", "Classifier", "CNN"]
+
+
+class CNN(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super().__init__()
+
+        self.layer_1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
+        self.batchnorm = nn.BatchNorm1d(hidden_dim)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(p=0.1)
+        self.layer_out = nn.Conv1d(hidden_dim, output_dim, kernel_size=1)
+
+    def forward(self, x):
+        """(N, C) -> (N, C, L) -> (N, C).
+        """
+        x = x.unsqueeze(-1)
+
+        x = self.layer_1(x)
+        # TODO: guard against batch size 1 in training mode (BatchNorm1d needs more than one value per channel to compute batch statistics)
+        x = self.batchnorm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+
+        x = self.layer_out(x)
+
+        x = x.squeeze(-1)
+        return x

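Despite the name, CNN is a pair of kernel-size-1 Conv1d layers applied to a length-1 sequence, i.e., effectively a per-sample MLP over the feature dimension. A quick shape check (a sketch, not part of the commit; assumes a working install of the package and torch):

    import torch
    from flexynesis.modules import CNN

    cnn = CNN(input_dim=100, hidden_dim=32, output_dim=8).eval()  # eval() keeps dropout/batchnorm deterministic
    x = torch.randn(16, 100)  # (N, C): 16 samples, 100 features
    y = cnn(x)                # (16, 100) -> (16, 100, 1) -> conv/bn/relu -> (16, 8, 1) -> (16, 8)
    print(y.shape)            # torch.Size([16, 8])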