scope: data filtering update
- consider protein domains in the dataset that map to any selected node, irrespective of the hierarchy level
aditya0by0 committed Feb 15, 2025
1 parent 45c1015 commit d3fd0f2
113 changes: 65 additions & 48 deletions chebai/preprocessing/datasets/scope/scope.py
@@ -351,9 +351,9 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
"""
print(f"Process graph")

-        sun_ids = self.select_classes(graph)
+        selected_sun_ids_per_lvl = self.select_classes(graph)

-        if not sun_ids:
+        if not selected_sun_ids_per_lvl:
raise RuntimeError("No sunid selected.")

df_cla = self._get_classification_data()
@@ -362,38 +362,35 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:

df_cla = df_cla[["sid", "sunid"] + hierarchy_levels]

-        # This filtering makes sure to consider only domains that belong to each `selected` hierarchy level,
-        # so that our data has domains that map to all levels of the taxonomy
-        for level, selected_sun_ids in sun_ids.items():
-            if selected_sun_ids:
-                df_cla = df_cla[
-                    df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)
-                ]
-
-        assert (
-            len(df_cla) > 1
-        ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
-
-        df_encoded = pd.get_dummies(
-            df_cla,
-            columns=hierarchy_levels,
-            drop_first=False,
-            sparse=True,
-        )
+        # Initialize selected target columns
+        df_encoded = df_cla[["sid", "sunid"]].copy()
+
+        lvl_to_target_cols_mapping = {}
+        # Iterate over only the selected sun_ids (nodes) to one-hot encode them
+        for level, selected_sun_ids in selected_sun_ids_per_lvl.items():
+            level_column = self.SCOPE_HIERARCHY[level]  # actual column name in df_cla
+            if level_column in df_cla.columns:
+                # Create a binary encoding for only the relevant sun_ids
+                for sun_id in selected_sun_ids:
+                    col_name = f"{level_column}_{sun_id}"
+                    df_encoded[col_name] = (df_cla[level_column] == sun_id).astype(bool)
+                    lvl_to_target_cols_mapping.setdefault(level_column, []).append(col_name)

-        pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
+        # Filter to keep only domains that map to at least one selected sunid at any level
+        df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)]
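
For intuition, the encoding-and-filter step introduced here can be sketched on toy data (a minimal sketch, not part of the commit; `selected` stands in for the output of `select_classes`, and `cl` for one SCOPe hierarchy column):

import pandas as pd

df_cla = pd.DataFrame(
    {"sid": ["d1a1xa_", "d1b2ya_"], "sunid": [100, 200], "cl": [46456, 48724]}
)
selected = {"cl": [46456]}  # hypothetical: sunids chosen per hierarchy level

df_encoded = df_cla[["sid", "sunid"]].copy()
for level_column, sun_ids in selected.items():
    for sun_id in sun_ids:
        # One boolean target column per selected node, named f"{level_column}_{sun_id}"
        df_encoded[f"{level_column}_{sun_id}"] = (df_cla[level_column] == sun_id).astype(bool)

# Keep domains matching at least one selected node at any level (label columns start at index 2)
df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)]
print(df_encoded)  # only d1a1xa_ remains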

-        encoded_target_cols = {}
-        for col in hierarchy_levels:
-            encoded_target_cols[col] = [
-                t_col for t_col in df_encoded.columns if t_col.startswith(col)
-            ]
+        pdb_chain_seq_mapping = self._parse_pdb_sequence_file()

encoded_target_columns = []
for level in hierarchy_levels:
-            encoded_target_columns.extend(encoded_target_cols[level])
+            encoded_target_columns.extend(lvl_to_target_cols_mapping[level])

sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns)
df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns]

for _, row in df_encoded.iterrows():
sid = row["sid"]
@@ -410,14 +407,19 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
chain_sequence = pdb_to_chain_mapping.get(chain_id, None)
if chain_sequence:
self._update_or_add_sequence(
-                        chain_sequence, row, sequence_hierarchy_df, encoded_target_cols
+                        chain_sequence,
+                        row,
+                        sequence_hierarchy_df,
+                        encoded_target_columns,
)

else:
# Add nodes and edges for chains in the mapping
for chain, chain_sequence in pdb_to_chain_mapping.items():
self._update_or_add_sequence(
-                        chain_sequence, row, sequence_hierarchy_df, encoded_target_cols
+                        chain_sequence,
+                        row,
+                        sequence_hierarchy_df,
+                        encoded_target_columns,
)

sequence_hierarchy_df.reset_index(inplace=True)
@@ -427,6 +429,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
sequence_hierarchy_df = sequence_hierarchy_df[
["id", "sids", "sequence"] + encoded_target_columns
]
+        # Ensure at least one label is True for each protein sequence
+        sequence_hierarchy_df = sequence_hierarchy_df[
+            sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)
+        ]
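
Assuming `_LABELS_START_IDX` is 3 for this schema (labels begin right after `id`, `sids`, and `sequence`), the final filter behaves like this small sketch:

import pandas as pd

LABELS_START_IDX = 3  # assumed value of self._LABELS_START_IDX for this schema
df = pd.DataFrame(
    {
        "id": [0, 1],
        "sids": [["d1a1xa_"], ["d1b2ya_"]],
        "sequence": ["MKTAYIAK", "GHEELLVK"],
        "cl_46456": [True, False],
        "cl_48724": [False, False],
    }
)
# Drop sequences whose label vector is entirely False
df = df[df.iloc[:, LABELS_START_IDX:].any(axis=1)]
print(df)  # only the first sequence survives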

with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout:
fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns)
@@ -458,7 +464,10 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:

@staticmethod
def _update_or_add_sequence(
-        sequence, row, sequence_hierarchy_df, encoded_col_names
+        sequence: str,
+        row: pd.Series,
+        sequence_hierarchy_df: pd.DataFrame,
+        encoded_target_columns: List[str],
):
"""
Updates an existing sequence entry or adds a new one to the DataFrame.
@@ -467,29 +476,25 @@ def _update_or_add_sequence(
sequence (str): Amino acid sequence of the chain.
row (pd.Series): Row data containing SCOPe hierarchy levels and associated values.
sequence_hierarchy_df (pd.DataFrame): DataFrame storing sequences and their hierarchy labels.
-            encoded_col_names (Dict[str, List[str]]): Mapping of hierarchy levels to encoded column names.
+            encoded_target_columns (List[str]): Column names, which must appear in the same order in `row` and `sequence_hierarchy_df`.
Raises:
AssertionError: If a sequence instance belongs to more than one hierarchy level.
"""
if sequence in sequence_hierarchy_df.index:
-            # Update encoded columns only if they are True
-            for col in encoded_col_names:
-                assert (
-                    sum(row[encoded_col_names[col]].tolist()) == 1
-                ), "An instance can belong to only one hierarchy level"
-                sliced_data = row[encoded_col_names[col]]
-                # Get the column name with the True value
-                true_column = sliced_data.idxmax() if sliced_data.any() else None
-                sequence_hierarchy_df.loc[sequence, true_column] = True
-
-            sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
+            # Update encoded columns using bitwise OR (values stay True if they were previously True)
+            sequence_hierarchy_df.loc[sequence, encoded_target_columns] = (
+                row[encoded_target_columns]
+                | sequence_hierarchy_df.loc[sequence, encoded_target_columns]
+            )
+
+            sequence_hierarchy_df.at[sequence, "sids"] = sequence_hierarchy_df.at[
+                sequence, "sids"
+            ] + [row["sid"]]

else:
# Add new row with sequence as the index and hierarchy data
-            new_row = row
+            new_row = row.to_dict()
new_row["sids"] = [row["sid"]]
sequence_hierarchy_df.loc[sequence] = new_row
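
The duplicate-sequence merge above can be illustrated with a self-contained sketch (toy label columns; the helper `update_or_add` is a hypothetical stand-in mirroring `_update_or_add_sequence`):

import pandas as pd

cols = ["cl_46456", "cl_48724"]
df = pd.DataFrame(columns=["sids"] + cols)

def update_or_add(sequence, row, df, cols):
    if sequence in df.index:
        # OR-merge labels so a True from any domain is preserved
        df.loc[sequence, cols] = row[cols] | df.loc[sequence, cols]
        df.at[sequence, "sids"] = df.at[sequence, "sids"] + [row["sid"]]
    else:
        new_row = {c: row[c] for c in cols}
        new_row["sids"] = [row["sid"]]
        df.loc[sequence] = new_row

r1 = pd.Series({"sid": "d1a1xa_", "cl_46456": True, "cl_48724": False})
r2 = pd.Series({"sid": "d1b2ya_", "cl_46456": False, "cl_48724": True})
update_or_add("MKTAYIAK", r1, df, cols)
update_or_add("MKTAYIAK", r2, df, cols)
# df now holds one row: sids == ["d1a1xa_", "d1b2ya_"], both labels True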

@@ -859,7 +864,7 @@ class SCOPeOver2000(_SCOPeOverX):


class SCOPeOver50(_SCOPeOverX):

THRESHOLD = 50


@@ -878,5 +883,17 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):

if __name__ == "__main__":
scope = SCOPeOver2000(scope_version="2.08")
-    g = scope._extract_class_hierarchy("dummy/path")
+    # g = scope._extract_class_hierarchy("dummy/path")
+    # # Save graph
+    # import pickle
+    # with open("graph.gpickle", "wb") as f:
+    #     pickle.dump(g, f)
+
+    # Load graph
+    import pickle
+
+    with open("graph.gpickle", "rb") as f:
+        g = pickle.load(f)
+
+    # print(len([node for node in g.nodes() if g.out_degree(node) > 10000]))
scope._graph_to_raw_dataset(g)
