Skip to content

Commit

Permalink
scope: modify select classes and labels save operation
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya0by0 committed Jan 24, 2025
1 parent c3ba8da commit 764b812
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions chebai/preprocessing/datasets/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import gzip
import os
import shutil
from abc import ABC
from abc import ABC, abstractmethod
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Generator, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple

import networkx as nx
import pandas as pd
Expand Down Expand Up @@ -350,21 +350,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
"""
print(f"Process graph")

sids = nx.get_node_attributes(graph, "sid")
levels = nx.get_node_attributes(graph, "level")

sun_ids = {}
sids_list = []

selected_sids_dict = self.select_classes(graph)

for sun_id, level in levels.items():
if sun_id in selected_sids_dict:
sun_ids.setdefault(level, []).append(sun_id)
sids_list.append(sids.get(sun_id))

# Remove root node, as it will True for all instances
sun_ids.pop("root", None)
sun_ids = self.select_classes(graph)

if not sun_ids:
raise RuntimeError("No sunid selected.")
Expand Down Expand Up @@ -440,6 +426,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
sequence_hierarchy_df = sequence_hierarchy_df[
["id", "sids", "sequence"] + encoded_target_columns
]

with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout:
fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns)

return sequence_hierarchy_df

def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
Expand Down Expand Up @@ -498,6 +488,11 @@ def _update_or_add_sequence(
new_row["sids"] = [row["sid"]]
sequence_hierarchy_df.loc[sequence] = new_row

@abstractmethod
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]:
# Override the return type of the method from superclass
pass

# ------------------------------ Phase: Setup data -----------------------------------
def setup_processed(self) -> None:
"""
Expand Down Expand Up @@ -755,24 +750,36 @@ def _name(self) -> str:
"""
return f"SCOPe{self.THRESHOLD}"

def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
# Filter nodes and create a dictionary of node and out-degree
sun_ids_dict = {
node: g.out_degree(node) # Store node and its out-degree
for node in g.nodes
if g.out_degree(node) >= self.THRESHOLD
}
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]:
"""
Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold.
# Return a sorted dictionary (by out-degree or node id)
sorted_dict = dict(
sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False)
)
This method iterates over the nodes in the graph, counting the number of successors for each node.
Nodes with a number of successors greater than or equal to the defined threshold are selected.
Note:
The input graph must be transitive closure of a directed acyclic graph.
filename = "classes.txt"
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys())
Args:
g (nx.Graph): The graph representing the dataset.
*args: Additional positional arguments (not used).
**kwargs: Additional keyword arguments (not used).
Returns:
Dict: A dict containing selected nodes at each hierarchy level.
return sorted_dict
Notes:
- The `THRESHOLD` attribute should be defined in the subclass of this class.
"""
selected_sunids_for_level = {}
for node, attr_dict in g.nodes(data=True):
if g.out_degree(node) >= self.THRESHOLD:
selected_sunids_for_level.setdefault(attr_dict["level"], []).append(
node
)
# Remove root node, as it will True for all instances
selected_sunids_for_level.pop("root", None)
return selected_sunids_for_level


class _SCOPeOverXPartial(_SCOPeOverX, ABC):
Expand Down Expand Up @@ -860,6 +867,6 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):


if __name__ == "__main__":
scope = SCOPE(scope_version=2.08)
g = scope._extract_class_hierarchy("d")
scope = SCOPeOver2000(scope_version=2.08)
g = scope._extract_class_hierarchy("dummy/path")
scope._graph_to_raw_dataset(g)

0 comments on commit 764b812

Please sign in to comment.