Skip to content

Commit

Permalink
black linter run
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Jan 8, 2024
1 parent 5fff44b commit 71000c9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
8 changes: 7 additions & 1 deletion stable_gnn/embedding/embedding_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,13 @@ def _get_emb_settings(loss_name: str) -> Dict[str, Any]:
raise NameError

def build_embeddings(
self, loss_name: str, conv: str, data: List[Graph], device: device, number_of_trials: int, tune_out: bool = False
self,
loss_name: str,
conv: str,
data: List[Graph],
device: device,
number_of_trials: int,
tune_out: bool = False,
) -> NDArray:
"""Build embeddings based on passed dataset and settings
Expand Down
12 changes: 10 additions & 2 deletions stable_gnn/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ def _perturb_features_on_node(feature_matrix: NDArray, target: int, use_random:
return features_perturb

def _data_generation(
self, target: Optional[int] = None, index_to_perturb: Optional[list] = None, num_samples: int = 100, pred_threshold: float = 0.1
self,
target: Optional[int] = None,
index_to_perturb: Optional[list] = None,
num_samples: int = 100,
pred_threshold: float = 0.1,
) -> Tuple[pd.DataFrame, NDArray]:
if target is None:
neighbors = list(range(self.adj_matrix.shape[0]))
Expand Down Expand Up @@ -206,7 +210,11 @@ def _data_generation(
return data, neighbors

def _variable_selection(
self, target: Optional[int] = None, top_node: Optional[int] = None, num_samples: int = 100, pred_threshold: float = 0.1
self,
target: Optional[int] = None,
top_node: Optional[int] = None,
num_samples: int = 100,
pred_threshold: float = 0.1,
) -> Tuple[List[int], pd.DataFrame, Dict[int, float]]:
if target is None:
data, neighbors = self._data_generation(
Expand Down
4 changes: 3 additions & 1 deletion stable_gnn/model_link_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(
self.emb_conv_name = emb_conv_name
self.device = device

def train_test_edges(self, dataset: Graph) -> Tuple[List[List[int]], List[List[int]], List[List[int]], List[List[int]]]:
def train_test_edges(
self, dataset: Graph
) -> Tuple[List[List[int]], List[List[int]], List[List[int]], List[List[int]]]:
"""
Split dataset to train and test and calculate negative samples
Expand Down

0 comments on commit 71000c9

Please sign in to comment.