Merge pull request #68 from BIMSBbioinfo/cross_modal_pred
Define a new architecture for Cross-Modality Encoder/Decoder Network
borauyar authored Mar 17, 2024
2 parents ec49614 + 379276e commit e0f200c
Showing 6 changed files with 529 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/benchmarks.yml
@@ -89,6 +89,12 @@ jobs:
conda activate my_env
flexynesis --data_path dataset1 --model_class supervised_vae --target_variables Erlotinib,Crizotinib --fusion_type early --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types gex,cnv --outdir . --prefix erlotinib_svae --early_stop_patience 3 --use_loss_weighting True --evaluate_baseline_performance False
- name: Run CrossModalPred
shell: bash -l {0}
run: |
conda activate my_env
flexynesis --data_path dataset1 --model_class CrossModalPred --target_variables Erlotinib --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types gex,cnv --input_layers gex --output_layers cnv --outdir . --prefix erlotinib_crossmodal --early_stop_patience 3 --use_loss_weighting True --evaluate_baseline_performance False
- name: Run MultiTripletNetwork
shell: bash -l {0}
run: |
47 changes: 41 additions & 6 deletions flexynesis/__main__.py
@@ -16,7 +16,7 @@ def main():

parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required = True)
parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
choices=["GC", "GCN", "GAT", "SAGE"])
parser.add_argument("--target_variables",
@@ -36,6 +36,16 @@
parser.add_argument("--features_min", help="Minimum number of features to retain after feature selection", type=int, default = 500)
parser.add_argument("--features_top_percentile", help="Top percentile features to retain after feature selection", type=float, default = 20)
parser.add_argument("--data_types", help="(Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'", type=str, required = True)
parser.add_argument("--input_layers",
help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers"
"Comma-separated if multiple",
type=str, default = None
)
parser.add_argument("--output_layers",
help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers"
"Comma-separated if multiple",
type=str, default = None
)
parser.add_argument("--outdir", help="Path to the output folder to save the model outputs", type=str, default = os.getcwd())
parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default = 'job')
parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default = 'False')
@@ -71,9 +81,14 @@ def main():
"or --batch_variables."]))

# 3. Check for compatibility of fusion_type with DirectPredGCNN
if args.fusion_type == "early" and args.model_class == "DirectPredGCNN":
parser.error("The 'DirectPredGCNN' model cannot be used with early fusion type. "
"Use --fusion_type intermediate instead.")
if args.fusion_type == "early":
if args.model_class == "DirectPredGCNN":
parser.error("The 'DirectPredGCNN' model cannot be used with early fusion type. "
"Use --fusion_type intermediate instead.")
if args.model_class == 'CrossModalPred':
parser.error("The 'CrossModalPred' model cannot be used with early fusion type. "
"Use --fusion_type intermediate instead.")


# 4. Check for device availability if --accelerator is set.
if args.use_gpu:
@@ -108,6 +123,23 @@ def main():
else:
gnn_conv_type = None

# 6. Check CrossModalPred arguments
input_layers = args.input_layers
output_layers = args.output_layers
datatypes = args.data_types.strip().split(',')
if args.model_class == 'CrossModalPred':
# check if input output layers are matching the requested data types
if args.input_layers:
input_layers = input_layers.strip().split(',')
# Check if input_layers are a subset of datatypes
if not all(layer in datatypes for layer in input_layers):
raise ValueError(f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes}).")
# check if output_layers are a subset of datatypes
if args.output_layers:
output_layers = output_layers.strip().split(',')
if not all(layer in datatypes for layer in output_layers):
raise ValueError(f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes}).")

# Validate paths
if not os.path.exists(args.data_path):
raise FileNotFoundError(f"Input --data_path doesn't exist at: {args.data_path}")
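
Aside (not part of the diff): the CrossModalPred layer checks above in isolation, with example values, to show what they accept and reject:

datatypes = "gex,cnv,mut".strip().split(',')   # from --data_types
input_layers = "gex,mut".strip().split(',')    # from --input_layers
output_layers = "cnv".strip().split(',')       # from --output_layers

assert all(layer in datatypes for layer in input_layers)    # passes
assert all(layer in datatypes for layer in output_layers)   # passes
# A layer absent from the data types, e.g. --input_layers methylation,
# fails the same check and raises the ValueError shown above.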
@@ -120,6 +152,7 @@ class AvailableModels(NamedTuple):
supervised_vae: tuple[supervised_vae, str] = supervised_vae, "supervised_vae"
MultiTripletNetwork: tuple[MultiTripletNetwork, str] = MultiTripletNetwork, "MultiTripletNetwork"
DirectPredGCNN: tuple[DirectPredGCNN, str] = DirectPredGCNN, "DirectPredGCNN"
CrossModalPred: tuple[CrossModalPred, str] = CrossModalPred, "CrossModalPred"

available_models = AvailableModels()
model_class = getattr(available_models, args.model_class, None)
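
Aside (not part of the diff): a standalone sketch of the NamedTuple-based model registry pattern used above, with stand-in classes, showing how the getattr lookup resolves a --model_class string to a class:

from typing import NamedTuple

class DirectPred: pass       # stand-ins for the real model classes
class CrossModalPred: pass

class AvailableModels(NamedTuple):
    DirectPred: tuple = (DirectPred, "DirectPred")
    CrossModalPred: tuple = (CrossModalPred, "CrossModalPred")

available_models = AvailableModels()
entry = getattr(available_models, "CrossModalPred", None)  # None if unknown
model_class, config_name = entry
print(model_class.__name__, config_name)  # CrossModalPred CrossModalPred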
@@ -140,7 +173,7 @@ class AvailableModels(NamedTuple):
concatenate = True

data_importer = flexynesis.DataImporter(path = args.data_path,
data_types = args.data_types.strip().split(','),
data_types = datatypes,
concatenate = concatenate,
log_transform = args.log_transform == 'True',
correlation_threshold = args.correlation_threshold,
@@ -173,7 +206,9 @@ class AvailableModels(NamedTuple):
use_loss_weighting = args.use_loss_weighting == 'True',
early_stop_patience = int(args.early_stop_patience),
device_type = device_type,
gnn_conv_type = gnn_conv_type)
gnn_conv_type = gnn_conv_type,
input_layers = input_layers,
output_layers = output_layers)

# do a hyperparameter search training multiple models and get the best_configuration
model, best_params = tuner.perform_tuning()
8 changes: 8 additions & 0 deletions flexynesis/config.py
@@ -19,6 +19,14 @@
Integer(32, 128, name='batch_size'),
Categorical(epochs, name='epochs')
],
'CrossModalPred': [
Integer(16, 128, name='latent_dim'),
Integer(64, 512, name='hidden_dim'),
Integer(8, 32, name='supervisor_hidden_dim'),
Real(0.0001, 0.01, prior='log-uniform', name='lr'),
Integer(32, 128, name='batch_size'),
Categorical(epochs, name='epochs')
],
'MultiTripletNetwork': [
Integer(16, 128, name='latent_dim'),
Integer(64, 512, name='hidden_dim'),
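Aside (not part of the diff): the Integer/Real/Categorical entries are search-space dimensions, matching the scikit-optimize (skopt) constructors, so that library is assumed here. A standalone sketch of drawing one candidate configuration from the new CrossModalPred space; the epochs list below is a placeholder for the one defined earlier in config.py:

from skopt.space import Space, Integer, Real, Categorical

epochs = [100, 200]  # placeholder; the actual list is defined earlier in config.py

space = Space([
    Integer(16, 128, name='latent_dim'),
    Integer(64, 512, name='hidden_dim'),
    Integer(8, 32, name='supervisor_hidden_dim'),
    Real(0.0001, 0.01, prior='log-uniform', name='lr'),
    Integer(32, 128, name='batch_size'),
    Categorical(epochs, name='epochs'),
])

# Draw one random configuration; values pair up with the dimension names.
sample = space.rvs(n_samples=1)[0]
params = dict(zip([dim.name for dim in space.dimensions], sample))
print(params)  # e.g. {'latent_dim': 93, 'hidden_dim': 210, 'lr': 0.0007, ...}
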
8 changes: 7 additions & 1 deletion flexynesis/main.py
@@ -50,7 +50,8 @@ def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, surv_event_var = None, surv_time_var = None,
n_iter = 10, config_path = None, plot_losses = False,
val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1,
device_type = None, gnn_conv_type = None):
device_type = None, gnn_conv_type = None,
input_layers = None, output_layers = None):
self.dataset = dataset
self.model_class = model_class
self.target_variables = target_variables
@@ -72,6 +73,8 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.early_stop_patience = early_stop_patience
self.use_loss_weighting = use_loss_weighting
self.gnn_conv_type = gnn_conv_type
self.input_layers = input_layers
self.output_layers = output_layers

# If config_path is provided, use it
if config_path:
@@ -95,6 +98,9 @@ def objective(self, params, current_step, total_steps):
"use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type}
if self.model_class.__name__ == 'DirectPredGCNN':
model_args["gnn_conv_type"] = self.gnn_conv_type
if self.model_class.__name__ == 'CrossModalPred':
model_args["input_layers"] = self.input_layers
model_args["output_layers"] = self.output_layers

model = self.model_class(**model_args)
print(params)
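Aside (not part of the diff): the conditional above is the only model-specific branch in objective(): extra keyword arguments are appended to a shared dict before the model class is instantiated. A standalone sketch of that pattern with a stand-in class; the real CrossModalPred constructor lives in the new crossmodal_pred.py, which is not rendered in this excerpt:

class CrossModalPred:
    """Stand-in for flexynesis.models.CrossModalPred (not shown in this diff)."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs

model_class = CrossModalPred
model_args = {"use_loss_weighting": True, "device_type": "cpu"}  # shared args

if model_class.__name__ == 'CrossModalPred':
    model_args["input_layers"] = ["gex"]    # modalities to encode
    model_args["output_layers"] = ["cnv"]   # modalities to reconstruct

model = model_class(**model_args)
print(model.kwargs)
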
3 changes: 2 additions & 1 deletion flexynesis/models/__init__.py
@@ -2,5 +2,6 @@
from .direct_pred_gcnn import DirectPredGCNN
from .supervised_vae import supervised_vae
from .triplet_encoder import MultiTripletNetwork
from .crossmodal_pred import CrossModalPred

__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"]
__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"]
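
Aside (not part of the diff): with the export above in place, the new model is importable alongside the existing ones; a minimal check, assuming flexynesis is installed:

from flexynesis.models import CrossModalPred, DirectPred  # both resolve after this change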
465 changes: 465 additions & 0 deletions flexynesis/models/crossmodal_pred.py (new file; contents not rendered in this view)
