Commit

move modality-based graph filtering to another function; improve log messages
borauyar committed Feb 21, 2024
1 parent 81b9d5f commit 36e736d
Showing 1 changed file with 49 additions and 46 deletions.
95 changes: 49 additions & 46 deletions flexynesis/data.py
@@ -331,7 +331,7 @@ def __init__(self, path, data_types, log_transform = False, concatenate = False,
self.transformers = None

self.use_graph = use_graph
self.node_name = node_name # "gene_name" | "gene_id"
self.node_name = node_name
self.transform = transform

def import_data(self):
@@ -341,47 +341,36 @@ def import_data(self):

self.validate_data_folders(training_path, testing_path)

training_data = self.read_data(training_path)
testing_data = self.read_data(testing_path)
# raw data matrices as they exist in the data path
train_dat = self.read_data(training_path)
test_dat = self.read_data(testing_path)

if self.use_graph:
edge_list = self.read_graph(training_data, testing_data)
edge_list = self.read_graph(train_dat, test_dat)

# cleanup uninformative features/samples, subset annotation data, do feature selection on training data
train_dat, train_ann, train_samples, train_features = self.process_data(training_data, split = 'train')
test_dat, test_ann, test_samples, test_features = self.process_data(testing_data, split = 'test')
train_dat, train_ann, train_samples, train_features = self.process_data(train_dat, split = 'train')
test_dat, test_ann, test_samples, test_features = self.process_data(test_dat, split = 'test')

# harmonize feature sets in train/test
train_dat, test_dat = self.harmonize(train_dat, test_dat)

train_feature_ann = {}
test_feature_ann = {}
if self.use_graph:
# Now filter the graph edges based on provided_genes
# But this time separately for each modality
for k, v in train_dat.items():
mod_gene_list = v.index.to_list()
node_to_idx = {node: i for i, node in enumerate(mod_gene_list)}
mod_edge_list = []
for edge in edge_list:
src, dst = edge
if (src in mod_gene_list) and (dst in mod_gene_list):
mod_edge_list.append([node_to_idx[src], node_to_idx[dst]])
train_feature_ann[k] = {"edge_index": torch.tensor(mod_edge_list).T}
# Repeat the same for test data
for k, v in train_dat.items():
mod_gene_list = v.index.to_list()
node_to_idx = {node: i for i, node in enumerate(mod_gene_list)}
mod_edge_list = []
for edge in edge_list:
src, dst = edge
if (src in mod_gene_list) and (dst in mod_gene_list):
mod_edge_list.append([node_to_idx[src], node_to_idx[dst]])
test_feature_ann[k] = {"edge_index": torch.tensor(mod_edge_list).T}

# apply a second filter to the graph,
# this time by each data modality separately
print("\n[INFO] ----------------- Filtering graph by modality -----------------")
train_feature_ann = self.filter_graph_by_modality(train_dat, edge_list)
print("[INFO] Number of edges by modality in training data",
{x: train_feature_ann[x]['edge_index'].shape[1] for x in train_feature_ann.keys()})
test_feature_ann = self.filter_graph_by_modality(test_dat, edge_list)
print("[INFO] Number of edges by modality in test data",
{x: test_feature_ann[x]['edge_index'].shape[1] for x in test_feature_ann.keys()})

# log_transform
if self.log_transform:
print("transforming data to log scale")
print("[INFO] transforming data to log scale")
train_dat = self.transform_data(train_dat)
test_dat = self.transform_data(test_dat)

@@ -402,9 +391,8 @@ def import_data(self):
testing_dataset.dat = {'all': torch.cat([testing_dataset.dat[x] for x in testing_dataset.dat.keys()], dim = 1)}
testing_dataset.features = {'all': list(chain(*testing_dataset.features.values()))}

print("[INFO] Training Data Stats:\n", training_dataset.get_dataset_stats())
print("[INFO] Test Data Stats:\n", testing_dataset.get_dataset_stats())

print("[INFO] Training Data Stats: ", training_dataset.get_dataset_stats())
print("[INFO] Test Data Stats: ", testing_dataset.get_dataset_stats())
print("[INFO] Data import successful.")

return training_dataset, testing_dataset
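With concatenate set, the per-modality tensors are joined along the feature axis into a single 'all' matrix. A toy illustration of the torch.cat call above (shapes invented):

import torch

mrna = torch.randn(10, 500)  # 10 samples x 500 mRNA features
cnv = torch.randn(10, 200)   # 10 samples x 200 CNV features
combined = torch.cat([mrna, cnv], dim=1)
print(combined.shape)        # torch.Size([10, 700])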
@@ -427,16 +415,16 @@ def validate_data_folders(self, training_path, testing_path):
def read_data(self, folder_path):
data = {}
required_files = {'clin.csv'} | {f"{dt}.csv" for dt in self.data_types}
print("\n[INFO] ----------------- Reading Data -----------------")
print("\n[INFO] ----------------- Reading Data ----------------- ")
for file in required_files:
file_path = os.path.join(folder_path, file)
file_name = os.path.splitext(file)[0]
print(f"[INFO] Importing {file_path}...")
data[file_name] = pd.read_csv(file_path, index_col=0)
return data

def read_graph(self, training_data, testing_data):
print("\n[INFO] ----------------- Importing Graph Edges -----------------")
def read_graph(self, train_dat, test_dat):
print("\n[INFO] ----------------- Importing Graph Edges ----------------- ")
# NOTE: stringdb file hardcoded for now
edges_data_path = os.path.join(self.path, self.protein_links)
node_data_path = os.path.join(self.path, self.protein_aliases)
@@ -460,26 +448,39 @@ def fn(a):

available_features: list[str] = np.unique(graph_df[["protein1", "protein2"]].to_numpy()).tolist()

print("\n[INFO] Removing nodes/edges features which don't exist in omics data matrices")
print("[INFO] Removing nodes/edges features which don't exist in omics data matrices")
# Collect genes from both training and testing data matrices
provided_features = list({x for df in {**training_data, **testing_data}.values() for x in df.index})
provided_features = list({x for df in {**train_dat, **test_dat}.values() for x in df.index})

# Intersect with available genes to filter out non-existing ones
provided_features = set(available_features).intersection(provided_features)

initial_edge_list = stringdb_links_to_list(graph_df)
print("\n[INFO] Number of edges in initial edgelist",len(initial_edge_list))
print("[INFO] Number of edges in initial edgelist",len(initial_edge_list))
# Now filter the graph edges based on provided_features
edge_list = []
for edge in initial_edge_list:
src, dst = edge
if (src in provided_features) and (dst in provided_features):
edge_list.append(edge)
print("\n[INFO] Number of edges in pruned edgelist",len(edge_list))
print("[INFO] Number of edges in pruned edgelist",len(edge_list))
return edge_list
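Since provided_features is a set, each membership test in the pruning loop is O(1); the loop is equivalent to the comprehension below (a sketch of the same logic, not the committed code):

edge_list = [edge for edge in initial_edge_list
             if edge[0] in provided_features and edge[1] in provided_features]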

def filter_graph_by_modality(self, dat, edge_list):
feature_ann = {}
for k, v in dat.items():
mod_gene_list = v.index.to_list()
node_to_idx = {node: i for i, node in enumerate(mod_gene_list)}
mod_edge_list = []
for edge in edge_list:
src, dst = edge
if (src in mod_gene_list) and (dst in mod_gene_list):
mod_edge_list.append([node_to_idx[src], node_to_idx[dst]])
feature_ann[k] = {"edge_index": torch.tensor(mod_edge_list).T}
return feature_ann
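For illustration, a hypothetical call on toy data (the importer instance, gene names, and values are invented) shows the output format: one 2 x E edge_index tensor per modality, expressed in that modality's local row indices, with edges touching absent genes dropped:

import pandas as pd
import torch

dat = {"mrna": pd.DataFrame([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                            index=["TP53", "BRCA1", "EGFR"],  # modality features
                            columns=["s1", "s2"])}            # samples
edge_list = [["TP53", "BRCA1"], ["BRCA1", "MYC"]]  # MYC is absent, so its edge is dropped

feature_ann = importer.filter_graph_by_modality(dat, edge_list)  # hypothetical instance
print(feature_ann["mrna"]["edge_index"])  # tensor([[0], [1]]), i.e. TP53 -> BRCA1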

def process_data(self, data, split = 'train'):
print(f"\n[INFO] ---------- Processing Data ({split}) ----------")
print(f"\n[INFO] ----------------- Processing Data ({split}) ----------------- ")
# remove uninformative features and samples with no information (from data matrices)
dat = self.cleanup_data({x: data[x] for x in self.data_types})
ann = data['clin']
@@ -492,7 +493,7 @@ def process_data(self, data, split = 'train'):
return dat, ann, samples, features

def cleanup_data(self, df_dict):
print("\n[INFO] --------------- Cleaning Up Data ---------------")
print("\n[INFO] ----------------- Cleaning Up Data ----------------- ")
cleaned_dfs = {}
sample_masks = []

@@ -518,7 +519,7 @@ def cleanup_data(self, df_dict):
if np.sum(df.isna().sum()) > 0:
# Identify rows that contain missing values
missing_rows = df.isna().any(axis=1)
print("Imputing NA values to median of features, affected # of features ", np.sum(df.isna().sum()), " # of rows:",sum(missing_rows))
print("[INFO] Imputing NA values to median of features, affected # of features ", np.sum(df.isna().sum()), " # of rows:",sum(missing_rows))

# Only calculate the median for rows with missing values
medians = df[missing_rows].median(axis=1)
@@ -528,7 +529,7 @@ def cleanup_data(self, df_dict):
# Replace missing values in the row with the corresponding median
df.loc[i] = df.loc[i].fillna(medians[i])

print("Number of NA values: ",np.sum(df.isna().sum()))
print("[INFO] Number of NA values: ",np.sum(df.isna().sum()))

removed_features_count = original_features_count - df.shape[0]
print(f"[INFO] DataFrame {key} - Removed {removed_features_count} features.")
@@ -550,7 +551,7 @@ def cleanup_data(self, df_dict):
original_samples_count = cleaned_dfs[key].shape[1]
cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask]
removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1]
print(f"DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%).")
print(f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%).")

return cleaned_dfs

@@ -568,20 +569,22 @@ def filter(self, dat, min_features, top_percentile):
return dat

def harmonize(self, dat1, dat2):
print("\n[INFO] ------------ Harmonizing Data Sets ------------")
print("\n[INFO] ----------------- Harmonizing Data Sets ----------------- ")
# Get common features
common_features = {x: dat1[x].index.intersection(dat2[x].index) for x in self.data_types}
# Subset both datasets to only include common features
dat1 = {x: dat1[x].loc[common_features[x]] for x in dat1.keys()}
dat2 = {x: dat2[x].loc[common_features[x]] for x in dat2.keys()}
print("\n[INFO] ----------------- Finished Harmonizing ----------------- ")

return dat1, dat2
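Harmonization keeps only the features present in both splits; pandas' Index.intersection does the matching, e.g. (toy feature names):

import pandas as pd

idx1 = pd.Index(["TP53", "BRCA1", "EGFR"])
idx2 = pd.Index(["BRCA1", "EGFR", "MYC"])
print(idx1.intersection(idx2))  # Index(['BRCA1', 'EGFR'], dtype='object')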

def transform_data(self, data):
transformed_data = {x: np.log1p(data[x].T).T for x in data.keys()}
return transformed_data
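np.log1p is applied elementwise, so the transpose pair in transform_data leaves the result identical to applying it directly; a quick check (a sketch):

import numpy as np
import pandas as pd

df = pd.DataFrame([[0.0, 1.0], [3.0, 7.0]])
assert np.log1p(df.T).T.equals(np.log1p(df))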

def normalize_data(self, data, scaler_type="standard", fit=True):
print("\n[INFO] --------------- Normalizing Data ---------------")
print("\n[INFO] ----------------- Normalizing Data ----------------- ")
# notice matrix transpositions during fit and finally after transformation
# because data matrices have features on rows,
# while scaling methods assume features to be on the columns.
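The rest of normalize_data is collapsed in this view; below is a minimal sketch of the transposition pattern the comment describes, assuming scikit-learn scalers rather than the actual implementation:

import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler

def _normalize(df, scaler_type="standard", scaler=None):  # hypothetical helper
    # features are on rows, so fit/transform on the transpose, then transpose back
    if scaler is None:
        scaler = StandardScaler() if scaler_type == "standard" else MinMaxScaler()
        scaler.fit(df.T)
    normalized = pd.DataFrame(scaler.transform(df.T).T,
                              index=df.index, columns=df.columns)
    return normalized, scaler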
