
Commit

Merge pull request #83 from BIMSBbioinfo/82-implement-gnns-with-early-fusion

82 implement gnns with early fusion
borauyar authored Jul 3, 2024
2 parents 14bae65 + a72f6a3 commit 36e9dd4
Showing 7 changed files with 516 additions and 31 deletions.
36 changes: 24 additions & 12 deletions flexynesis/__main__.py
@@ -7,7 +7,7 @@
import flexynesis
from flexynesis.models import *
from lightning.pytorch.callbacks import EarlyStopping

from .data import STRING, MultiOmicDatasetNW

def main():
"""
@@ -57,7 +57,7 @@ def main():

parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required = True)
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNNEarly"], required = True)
parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
choices=["GC", "GCN", "GAT", "SAGE"])
parser.add_argument("--target_variables",
@@ -109,11 +109,6 @@ def main():
parser.add_argument("--string_organism", help="STRING DB organism id.", type=int, default=9606)
parser.add_argument("--string_node_name", help="Type of node name.", type=str, choices=["gene_name", "gene_id"], default="gene_name")


warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", "has been removed as a dependency of the")
warnings.filterwarnings("ignore", "The `srun` command is available on your system but is not used")

args = parser.parse_args()

# do some sanity checks on input arguments
@@ -157,10 +152,10 @@ def main():
torch.set_num_threads(args.threads)

# 5. check GNN arguments
if args.model_class == 'DirectPredGCNN':
if args.model_class == 'DirectPredGCNN' or args.model_class == 'GNNEarly':
if not args.gnn_conv_type:
warning_message = "\n".join([
"\n\n!!! When running DirectPredGCNN, a convolution type can be set",
"\n\n!!! When running DirectPredGCNN or GNNEarly, a convolution type can be set",
"with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.",
"Falling back on the default convolution type: GC !!!\n\n"
])
@@ -202,7 +197,8 @@ class AvailableModels(NamedTuple):
MultiTripletNetwork: tuple[MultiTripletNetwork, str] = MultiTripletNetwork, "MultiTripletNetwork"
DirectPredGCNN: tuple[DirectPredGCNN, str] = DirectPredGCNN, "DirectPredGCNN"
CrossModalPred: tuple[CrossModalPred, str] = CrossModalPred, "CrossModalPred"

GNNEarly: tuple[GNNEarly, str] = GNNEarly, "GNNEarly"

available_models = AvailableModels()
model_class = getattr(available_models, args.model_class, None)
if model_class is None:
@@ -237,6 +233,16 @@ class AvailableModels(NamedTuple):
downsample = args.subsample)
train_dataset, test_dataset = data_importer.import_data(force = True)

if args.model_class == 'GNNEarly':
# overlay datasets with network info
# this is a temporary solution
print("[INFO] Overlaying the dataset with network data from STRINGDB")
obj = STRING('STRING', "9606", "gene_name")
train_dataset = MultiOmicDatasetNW(train_dataset, obj.graph_df)
train_dataset.print_stats()
test_dataset = MultiOmicDatasetNW(test_dataset, obj.graph_df)


# print feature logs to file (we use these tables to track which features are dropped/selected and why)
feature_logs = data_importer.feature_logs
for key in feature_logs.keys():
@@ -347,14 +353,20 @@ class AvailableModels(NamedTuple):
print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0])
var = model.target_variables[0]
metrics = pd.DataFrame()

# in the case when GNNEarly was used, we use the initial MultiOmicDataset for train/test
# because GNNEarly requires a modified dataset structure to fit the networks (temporary solution)
train = train_dataset.multiomic_dataset if args.model_class == 'GNNEarly' else train_dataset
test = test_dataset.multiomic_dataset if args.model_class == 'GNNEarly' else test_dataset

if var != model.surv_event_var:
metrics = flexynesis.evaluate_baseline_performance(train_dataset, test_dataset,
metrics = flexynesis.evaluate_baseline_performance(train, test,
variable_name = var,
n_folds=5,
n_jobs = int(args.threads))
if model.surv_event_var and model.surv_time_var:
print("[INFO] Computing off-the-shelf method performance on survival variable:",model.surv_time_var)
metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train_dataset, test_dataset,
metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train, test,
model.surv_time_var,
model.surv_event_var,
n_folds = 5,
8 changes: 8 additions & 0 deletions flexynesis/config.py
@@ -39,5 +39,13 @@
Categorical(epochs, name="epochs"),
Integer(8, 32, name='supervisor_hidden_dim'),
Categorical(['relu'], name="activation")
],
'GNNEarly': [
Integer(16, 128, name='latent_dim'),
Real(0.2, 1, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim
Real(0.0001, 0.01, prior='log-uniform', name='lr'),
Integer(8, 32, name='supervisor_hidden_dim'),
Categorical(epochs, name='epochs'),
Categorical(['relu'], name="activation")
]
}
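
Note (not part of the commit): the entries above are skopt search-space dimensions. A minimal sketch of how such a space can be sampled, assuming a placeholder epochs list since `epochs` is defined elsewhere in config.py:

from skopt.space import Space, Integer, Real, Categorical

# Dimensions mirroring the 'GNNEarly' entry above; the epochs values are a stand-in.
gnn_early_space = Space([
    Integer(16, 128, name="latent_dim"),
    Real(0.2, 1, name="hidden_dim_factor"),       # relative size of hidden_dim w.r.t. input_dim
    Real(0.0001, 0.01, prior="log-uniform", name="lr"),
    Integer(8, 32, name="supervisor_hidden_dim"),
    Categorical([100, 200, 300], name="epochs"),  # placeholder for the shared `epochs` list
    Categorical(["relu"], name="activation"),
])

# Draw a few random configurations; HyperparameterTuning in main.py consumes spaces like these.
for point in gnn_early_space.rvs(n_samples=3, random_state=0):
    print(dict(zip([d.name for d in gnn_early_space.dimensions], point)))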
120 changes: 117 additions & 3 deletions flexynesis/data.py
@@ -200,7 +200,7 @@ def import_data(self, force=False):
initial_edge_list = graph_df.to_numpy().tolist()
# Read STRING by default.
elif isinstance(self.graph, STRING):
graph_df = self.graph.df
graph_df = self.graph.graph_df
available_features = np.unique(graph_df[["protein1", "protein2"]].to_numpy()).tolist()
initial_edge_list = stringdb_links_to_list(graph_df)
else:
@@ -679,6 +679,7 @@ def get_dataset_stats(self):
stats = {': '.join(['feature_count in', x]): self.dat[x].shape[1] for x in self.dat.keys()}
stats['sample_count'] = len(self.samples)
return(stats)


# given a MultiOmicDataset object, convert to Triplets (anchor,positive,negative)
class TripletMultiOmicDataset(Dataset):
@@ -714,7 +715,100 @@ def get_label_indices(self, labels):
for label in labels_set}
return labels_set, label_to_indices

class MultiOmicDatasetNW(Dataset):
def __init__(self, multiomic_dataset, interaction_df):
self.multiomic_dataset = multiomic_dataset
self.interaction_df = interaction_df

# Precompute common features and edge index
self.common_features = self.find_common_features()
self.gene_to_index = {gene: idx for idx, gene in enumerate(self.common_features)}
self.edge_index = self.create_edge_index()
self.samples = self.multiomic_dataset.samples
self.variable_types = self.multiomic_dataset.variable_types
self.label_mappings = self.multiomic_dataset.label_mappings
self.ann = self.multiomic_dataset.ann

# Precompute all node features for all samples
self.node_features_tensor = self.precompute_node_features()

# Store labels for all samples
self.labels = {target_name: labels for target_name, labels in self.multiomic_dataset.ann.items()}

# Store sample identifiers
self.samples = self.multiomic_dataset.samples

def find_common_features(self):
common_features = set.intersection(*(set(features) for features in self.multiomic_dataset.features.values()))
interaction_genes = set(self.interaction_df['protein1']).union(set(self.interaction_df['protein2']))
return list(common_features.intersection(interaction_genes))

def create_edge_index(self):
filtered_df = self.interaction_df[
(self.interaction_df['protein1'].isin(self.common_features)) &
(self.interaction_df['protein2'].isin(self.common_features))
]
edge_list = [(self.gene_to_index[row['protein1']], self.gene_to_index[row['protein2']]) for index, row in filtered_df.iterrows()]
return torch.tensor(edge_list, dtype=torch.long).t()

def precompute_node_features(self):
# Find indices of common features in each data matrix
feature_indices = {data_type: [self.multiomic_dataset.features[data_type].get_loc(gene)
for gene in self.common_features]
for data_type in self.multiomic_dataset.dat}
# Create a tensor to store all features [num_samples, num_nodes, num_data_types]
num_samples = len(self.samples)
num_nodes = len(self.common_features)
num_data_types = len(self.multiomic_dataset.dat)
all_features = torch.empty((num_samples, num_nodes, num_data_types), dtype=torch.float)

# Extract features for each data type and place them in the tensor
for i, data_type in enumerate(self.multiomic_dataset.dat):
# Get the data matrix
data_matrix = self.multiomic_dataset.dat[data_type]
# Use advanced indexing to extract features for all samples at once
indices = feature_indices[data_type]
if indices: # Ensure there are common features in this data type
all_features[:, :, i] = data_matrix[:, indices]
return all_features

def __getitem__(self, idx):
node_features_tensor = self.node_features_tensor[idx]
y_dict = {target_name: self.labels[target_name][idx] for target_name in self.labels}
return node_features_tensor, y_dict, self.samples[idx]

def __len__(self):
return len(self.samples)

def print_stats(self):
"""
Prints various statistics about the graph.
"""
num_nodes = len(self.common_features)
num_edges = self.edge_index.size(1)
num_node_features = self.node_features_tensor.size(2)

# Calculate degree for each node
degrees = torch.zeros(num_nodes, dtype=torch.long)
degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0]))
degrees.index_add_(0, self.edge_index[1], torch.ones_like(self.edge_index[1])) # For undirected graphs

num_singletons = torch.sum(degrees == 0).item()
non_singletons = degrees[degrees > 0]

mean_edges_per_node = non_singletons.float().mean().item() if len(non_singletons) > 0 else 0
median_edges_per_node = non_singletons.median().item() if len(non_singletons) > 0 else 0
max_edges = degrees.max().item()

print("Dataset Statistics:")
print(f"Number of nodes: {num_nodes}")
print(f"Total number of edges: {num_edges}")
print(f"Number of node features per node: {num_node_features}")
print(f"Number of singletons (nodes with no edges): {num_singletons}")
print(f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}")
print(f"Median number of edges per node (excluding singletons): {median_edges_per_node}")
print(f"Max number of edges per node: {max_edges}")

class MultiOmicPYGDataset(PYGDataset):
required = ["variable_types", "features", "samples", "label_mappings", "feature_ann"]

@@ -811,7 +905,11 @@ def __init__(self, root: str, organism: int = 9606, node_name: str = "gene_name"
self.organism = organism
self.node_name = node_name
super().__init__(os.path.join(root, self.base_folder))
self.df = read_user_graph(self.processed_paths[0], sep=",", header=0, index_col=0)

if not os.path.exists(self.processed_paths[0]):
self.download()
self.process()
self.graph_df = pd.read_csv(self.processed_paths[0], sep=",", header=0, index_col=0)

def len(self) -> int:
return 0
@@ -839,8 +937,8 @@ def download(self) -> None:

def process(self) -> None:
graph_df = read_stringdb_graph(self.node_name, self.raw_paths[0], self.raw_paths[1])
# Drop nans and save to disk.
graph_df.dropna().to_csv(self.processed_paths[0])
self.graph_df = graph_df # Storing the DataFrame in the instance


def read_user_graph(fpath, sep=" ", header=None, **pd_read_csv_kw):
@@ -859,6 +957,22 @@ def read_stringdb_links(fname):
df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1])
return df

def read_stringdb_links_test(fname):
df = pd.read_csv(fname, header=0, sep=" ")
df = df[df.combined_score > 800]
df = df[df.combined_score > df.combined_score.quantile(0.9)]
df_expanded = pd.concat([
df.rename(columns={'protein1': 'protein', 'protein2': 'partner'}),
df.rename(columns={'protein2': 'protein', 'protein1': 'partner'})
])
# Sort the expanded DataFrame by 'combined_score' in descending order
df_expanded_sorted = df_expanded.sort_values(by='combined_score', ascending=False)
# Reduce to unique interactions to avoid counting duplicates
df_expanded_unique = df_expanded_sorted.drop_duplicates(subset=['protein', 'partner'])
top_interactions = df_expanded_unique.groupby('protein').head(5)
df = top_interactions.rename(columns={'protein': 'protein1', 'partner': 'protein2'})
df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1])
return df
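
Note (toy illustration, not part of the commit): the expand-and-top-k step above duplicates each link in both directions so every protein appears in the 'protein' column, then keeps only the strongest-scoring partners per protein. On made-up data, with head(2) instead of head(5):

import pandas as pd

toy = pd.DataFrame({
    "protein1": ["9606.A", "9606.A", "9606.B"],
    "protein2": ["9606.B", "9606.C", "9606.C"],
    "combined_score": [950, 900, 850],
})
expanded = pd.concat([
    toy.rename(columns={"protein1": "protein", "protein2": "partner"}),
    toy.rename(columns={"protein2": "protein", "protein1": "partner"}),
]).sort_values("combined_score", ascending=False).drop_duplicates(["protein", "partner"])
print(expanded.groupby("protein").head(2))  # top 2 partners per protein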

def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]:
if node_name == "gene_id":
18 changes: 12 additions & 6 deletions flexynesis/main.py
@@ -19,7 +19,9 @@

import os, yaml
from skopt.space import Integer, Categorical, Real
from .data import STRING

torch.set_float32_matmul_precision("medium")

class HyperparameterTuning:
"""
@@ -104,7 +106,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.input_layers = input_layers
self.output_layers = output_layers

self.DataLoader = DataLoader # use torch data loader by default
self.DataLoader = torch.utils.data.DataLoader # use torch data loader by default

if self.model_class.__name__ == 'MultiTripletNetwork':
self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0])
@@ -127,7 +129,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
else:
raise ValueError(f"'{self.config_name}' not found in the default config.")

def get_batch_space(self, min_size = 16, max_size = 256):
def get_batch_space(self, min_size = 32, max_size = 256):
m = int(np.log2(len(self.dataset) * 0.8))
st = int(np.log2(min_size))
end = int(np.log2(max_size))
@@ -147,8 +149,9 @@ def setup_trainer(self, params, current_step, total_steps, full_train = False):
if self.early_stop_patience > 0 and full_train == False:
early_stop_callback = self.init_early_stopping()
mycallbacks.append(early_stop_callback)

trainer = pl.Trainer(
precision = '16-mixed', # mixed precision training
max_epochs=int(params['epochs']),
log_every_n_steps=5,
callbacks=mycallbacks,
Expand All @@ -173,7 +176,7 @@ def objective(self, params, current_step, total_steps, full_train = False):
"device_type": self.device_type,
}

if self.model_class.__name__ == 'DirectPredGCNN':
if self.model_class.__name__ == 'DirectPredGCNN' or self.model_class.__name__ == 'GNNEarly':
model_args['gnn_conv_type'] = self.gnn_conv_type
if self.model_class.__name__ == 'CrossModalPred':
model_args['input_layers'] = self.input_layers
@@ -208,11 +211,14 @@ def objective(self, params, current_step, total_steps, full_train = False):
print(f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}")
train_subset = torch.utils.data.Subset(self.loader_dataset, train_index)
val_subset = torch.utils.data.Subset(self.loader_dataset, val_index)
train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=True, drop_last=True)
val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=False)
train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']),
pin_memory=True, shuffle=True, drop_last=True, num_workers = 4, prefetch_factor = None, persistent_workers = True)
val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']),
pin_memory=True, shuffle=False, num_workers = 4, prefetch_factor = None, persistent_workers = True)

model = self.model_class(**model_args)
trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps)
print(f"[INFO] hpo config:{params}")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
if early_stop_callback.stopped_epoch:
epochs.append(early_stop_callback.stopped_epoch)
4 changes: 2 additions & 2 deletions flexynesis/models/__init__.py
@@ -3,5 +3,5 @@
from .supervised_vae import supervised_vae
from .triplet_encoder import MultiTripletNetwork
from .crossmodal_pred import CrossModalPred

__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"]
from .gnn_early import GNNEarly
__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNNEarly"]

