Skip to content

Commit

Permalink
fix some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Dec 4, 2023
1 parent 3f0c277 commit 07d38d3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion stable_gnn/embedding/sampling/abstract_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _sample_negative(self, batch: Tensor, num_negative_samples: int) -> Tensor:
Sample negative edges for batch of nodes
:param batch: (Batch): Nodes for negative sampling
:param num_negative_samples: (int): number of negative samples for each edge
:param num_negative_samples: (int): Number of negative samples for each edge
:return: (Tensor): Negative samples
"""
a, _ = subgraph(batch, self.data.edge_index)
Expand Down
35 changes: 18 additions & 17 deletions stable_gnn/fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def run(
interior_classifier: str = "rf",
verbose: bool = False,
multiplier: int = 1,
random_state: int = None,
):
random_state: int = 32,
) -> Dict[str, Any]:
"""
Correct fairness and calculate accuracy and fairness values
Expand Down Expand Up @@ -87,7 +87,7 @@ def run(

return ans

def _ml_model(self, df, random_state, estimator, prefit):
def _ml_model(self, df: pd.DataFrame, random_state: int, estimator: BaseEstimator, prefit: bool) -> Dict[str, Any]:
y = df.drop("target", axis=1)
x = df["target"]

Expand Down Expand Up @@ -155,13 +155,13 @@ def _ml_model(self, df, random_state, estimator, prefit):

def _lp_solver(
self,
d,
number_iterations=10,
classifier=RandomForestClassifier(),
verbose=False,
multiplier=1,
random_state=None,
):
d: Dict,
number_iterations: int = 10,
classifier: BaseEstimator = RandomForestClassifier(),
verbose: bool = False,
multiplier: int = 1,
random_state: int = 32,
) -> Dict[str, Any]:
one_group = multiplier * d["one_group"]
zero_group = multiplier * d["zero_group"]

Expand Down Expand Up @@ -325,7 +325,7 @@ def _lp_solver(
print("Fitting is finished")
return ans

def _predictor(self, solved, d, verbose=False):
def _predictor(self, solved: Dict, d: Dict, verbose: bool = False) -> Dict[str, Any]:
if verbose:
print("Predicting in process")
one_predictor_array = solved["one_predictor_array"]
Expand Down Expand Up @@ -364,13 +364,14 @@ def _predictor(self, solved, d, verbose=False):

return ans

def _cuae(self, y_true, y_pred, sensitive_features) -> Dict[str, Any]:
def _cuae(self, y_true: List, y_pred: List, sensitive_features: List) -> Dict[str, Any]:
"""
Calculate metrics
y_true - stands for the true label
y_pred - a forecast
sensitive_features - sensitive attribute
:param y_true: (List) True label
:param y_pred: (List) Prediction
:param sensitive_features: (List) Sensitive attributes
:return: (Dict): Dictionary containing metrics for fairness calculation
"""
true = np.array(y_true)
pred = np.array(y_pred)
Expand Down Expand Up @@ -407,15 +408,15 @@ def _cuae(self, y_true, y_pred, sensitive_features) -> Dict[str, Any]:
ans = {"df": df, "diff": total_diff, "ratio": total_ratio, "variation": variation}
return ans

def _zeros_ones_to_classes(self, x, length=3):
def _zeros_ones_to_classes(self, x: List, length: int = 3):
n = int(len(x) / length)
p = []
for i in range(n):
z = x[i * length : i * length + length]
p.append(z.argmax())
return np.array(p, dtype=int)

def _answer_creator(self, x, y, grouper):
def _answer_creator(self, x: List, y: List, grouper: List):
x = np.array(x) # array of 1
y = np.array(y) # array of 0
grouper = np.array(grouper)
Expand Down

0 comments on commit 07d38d3

Please sign in to comment.