Updated code and examples for compatibility with l_p_norm kwarg #11

Merged 1 commit on Apr 25, 2024
5 changes: 4 additions & 1 deletion error_parity/evaluation.py
@@ -36,6 +36,7 @@ def eval_accuracy_and_equalized_odds(
y_true: np.ndarray,
y_pred_binary: np.ndarray,
sensitive_attr: np.ndarray,
l_p_norm: int = np.inf,
display: bool = False,
) -> tuple[float, float]:
"""Evaluate accuracy and equalized odds of the given predictions.
@@ -48,6 +49,8 @@ def eval_accuracy_and_equalized_odds(
The predicted class labels.
sensitive_attr : np.ndarray
The sensitive attribute data.
l_p_norm : int, optional
The norm to use for the constraint violation, by default np.inf.
display : bool, optional
Whether to print results or not, by default False.

@@ -68,7 +71,7 @@
roc_points = np.vstack(roc_points)

linf_constraint_violation = [
np.linalg.norm(roc_points[i] - roc_points[j], ord=np.inf)
np.linalg.norm(roc_points[i] - roc_points[j], ord=l_p_norm)
for i, j in product(range(n_groups), range(n_groups))
if i < j
]
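For context, a minimal sketch of the new kwarg in use (fully synthetic data, for illustration only):

```python
import numpy as np
from error_parity.evaluation import eval_accuracy_and_equalized_odds

# Synthetic binary labels, predictions, and a 3-group sensitive attribute
rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=1_000)
y_pred = rng.integers(0, 2, size=1_000)
group = rng.integers(0, 3, size=1_000)

# Default behavior is unchanged: l-infinity norm of pairwise ROC-point gaps
acc, eq_odds_linf = eval_accuracy_and_equalized_odds(y_true, y_pred, group)

# New in this PR: measure the constraint violation under another l-p norm
acc, eq_odds_l1 = eval_accuracy_and_equalized_odds(y_true, y_pred, group, l_p_norm=1)
```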
67 changes: 34 additions & 33 deletions error_parity/pareto_curve.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import os
import copy
import logging
import traceback
from functools import partial
@@ -25,42 +26,31 @@


def fit_and_evaluate_postprocessing(
predictor: callable,
postproc_template: RelaxedThresholdOptimizer,
tolerance: float,
fit_data: tuple,
eval_data: tuple | dict[tuple],
fairness_constraint: str = "equalized_odds",
false_pos_cost: float = 1.,
false_neg_cost: float = 1.,
max_roc_ticks: int = 200,
seed: int = 42,
y_fit_pred_scores: np.ndarray = None, # pre-computed predictions on the fit data
bootstrap: bool = True,
bootstrap_kwargs: dict = None,
**bootstrap_kwargs: dict,
) -> dict[str, dict]:
"""Fit and evaluate a postprocessing intervention on the given predictor.

Parameters
----------
predictor : callable
The callable predictor to fit postprocessing on.
postproc_template: RelaxedThresholdOptimizer
An object that serves as the template to copy when creating the
postprocessing optimizer.
tolerance : float
The tolerance (or slack) for fairness constraint fulfillment.
The tolerance (or slack) for fairness constraint fulfillment. This value
will override the `tolerance` attribute of the `postproc_template` object.
fit_data : tuple
The data used to fit postprocessing.
eval_data : tuple or dict[tuple]
The data or sequence of data to evaluate postprocessing on.
If a tuple is provided, will call it "eval" data in the returned results
dictionary; if a dict is provided, will assume {<key_1>: <data_1>, ...}.
fairness_constraint : str, optional
The name of the fairness constraint to use, by default "equalized_odds".
false_pos_cost : float, optional
The cost of a false positive error, by default 1.
false_neg_cost : float, optional
The cost of a false negative error, by default 1.
max_roc_ticks : int, optional
The maximum number of ticks (precision) to use when computing
group-specific ROC curves, by default 200.
seed : int, optional
The random seed, by default 42.
y_fit_pred_scores : np.ndarray, optional
@@ -85,15 +75,8 @@ def fit_and_evaluate_postprocessing(
>>> "test": {"accuracy": 0.65, "...": "..."},
>>> }
"""
clf = RelaxedThresholdOptimizer(
predictor=predictor,
constraint=fairness_constraint,
tolerance=tolerance,
false_pos_cost=false_pos_cost,
false_neg_cost=false_neg_cost,
max_roc_ticks=max_roc_ticks,
seed=seed,
)
clf = copy.copy(postproc_template)
clf.tolerance = tolerance

# Unpack data
X_fit, y_fit, s_fit = fit_data
@@ -105,7 +88,7 @@
# (Theoretical) fit results
results["fit-theoretical"] = {
"accuracy": 1 - clf.cost(1.0, 1.0),
fairness_constraint: clf.constraint_violation(),
clf.constraint: clf.constraint_violation(),
}

ALLOWED_ABS_ERROR = 1e-5
@@ -123,12 +106,16 @@ def _evaluate_on_data(data: tuple):
X, Y, S = data

if bootstrap:
kwargs = bootstrap_kwargs or dict(
# Default kwargs for bootstrapping
kwargs = dict(
confidence_pct=95,
seed=seed,
threshold=0.50,
)

# Update kwargs with any extra bootstrap kwargs
kwargs.update(bootstrap_kwargs)

eval_func = partial(
evaluate_predictions_bootstrap,
**kwargs,
@@ -156,8 +143,9 @@ def _evaluate_on_data(data: tuple):
def compute_postprocessing_curve(
model: object,
fit_data: tuple,
eval_data: tuple or dict[tuple],
eval_data: tuple | dict[tuple],
fairness_constraint: str = "equalized_odds",
l_p_norm: int = np.inf,
bootstrap: bool = True,
tolerance_ticks: list = DEFAULT_TOLERANCE_TICKS,
tolerance_tick_step: float = None,
@@ -180,7 +168,10 @@
format as `fit_data`), or a dictionary of <data_name>-><data_triplet>
containing multiple datasets to evaluate on.
fairness_constraint : str, optional
_description_, by default "equalized_odds"
The fairness constraint to use, by default "equalized_odds".
l_p_norm : int, optional
The norm to use when computing the fairness constraint, by default np.inf.
Note: only compatible with the "equalized_odds" constraint.
bootstrap : bool, optional
Whether to compute uncertainty estimates via bootstrapping, by default
True.
@@ -210,15 +201,25 @@ def callable_predictor(X) -> np.ndarray:
assert 1 <= len(preds.shape) <= 2, f"Model outputs predictions in shape {preds.shape}"
return preds if len(preds.shape) == 1 else preds[:, -1]

# Pre-compute predictions on the fit data
X_fit, _, _ = fit_data
y_fit_pred_scores = callable_predictor(X_fit)

postproc_template = RelaxedThresholdOptimizer(
predictor=callable_predictor,
constraint=fairness_constraint,
l_p_norm=l_p_norm,
)

def _func_call(tol: float):
try:
return fit_and_evaluate_postprocessing(
predictor=callable_predictor,
postproc_template=postproc_template,
tolerance=tol,
fit_data=fit_data,
eval_data=eval_data,
fairness_constraint=fairness_constraint,
bootstrap=bootstrap,
y_fit_pred_scores=y_fit_pred_scores,
**kwargs)

except Exception as exc:
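A sketch of the updated, template-based call pattern. The synthetic data and toy `predictor` below are placeholders; `RelaxedThresholdOptimizer` is assumed to be the package's top-level export:

```python
import numpy as np
from error_parity import RelaxedThresholdOptimizer
from error_parity.pareto_curve import fit_and_evaluate_postprocessing

# Synthetic data standing in for a real task (illustrative only)
rng = np.random.default_rng(0)
X = rng.random((2_000, 4))
y = (X[:, 0] + 0.1 * rng.standard_normal(2_000) > 0.5).astype(int)
s = rng.integers(0, 2, size=2_000)
fit_data = (X[:1_000], y[:1_000], s[:1_000])
test_data = (X[1_000:], y[1_000:], s[1_000:])

def predictor(X):
    return X[:, 0]  # pretend the first feature is a calibrated score

# The template carries all postprocessing hyperparameters;
# fit_and_evaluate_postprocessing copies it and overrides only `tolerance`.
postproc_template = RelaxedThresholdOptimizer(
    predictor=predictor,
    constraint="equalized_odds",
    l_p_norm=1,
)

results = fit_and_evaluate_postprocessing(
    postproc_template=postproc_template,
    tolerance=0.05,
    fit_data=fit_data,
    eval_data={"test": test_data},
    bootstrap=True,
    confidence_pct=90,  # forwarded through the new **bootstrap_kwargs
)
```

Passing `confidence_pct=90` here overrides only that one default; the other bootstrap defaults (`seed`, `threshold`) are kept via the `kwargs.update(bootstrap_kwargs)` path shown in the diff.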
7 changes: 6 additions & 1 deletion error_parity/plotting.py
@@ -190,9 +190,14 @@ def plot_postprocessing_solution(
)

# Set axis settings
fairness_constr_str = postprocessed_clf.constraint.replace("_", " ")
if postprocessed_clf.constraint == "equalized_odds":
l_p_norm = postprocessed_clf.l_p_norm if postprocessed_clf.l_p_norm != np.inf else r"\infty"
fairness_constr_str += f" $\\ell_{l_p_norm}$"

plt.suptitle(f"Solution to {postprocessed_clf.tolerance}-relaxed optimum", y=0.96)
plt.title(
f"(fairness constraint: {postprocessed_clf.constraint.replace('_', ' ')})",
f"(fairness constraint: {fairness_constr_str})",
fontsize="small",
)

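The title-building logic above can be sanity-checked in isolation; a small sketch (the helper name is made up for illustration):

```python
import numpy as np

def constraint_title(constraint: str, l_p_norm) -> str:
    # Mirrors the axis-title logic in plot_postprocessing_solution
    label = constraint.replace("_", " ")
    if constraint == "equalized_odds":
        norm_str = l_p_norm if l_p_norm != np.inf else r"\infty"
        label += f" $\\ell_{norm_str}$"
    return label

print(constraint_title("equalized_odds", np.inf))  # equalized odds $\ell_\infty$
print(constraint_title("equalized_odds", 2))       # equalized odds $\ell_2$
print(constraint_title("demographic_parity", 2))   # demographic parity (norm not shown)
```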
18 changes: 17 additions & 1 deletion error_parity/threshold_optimizer.py
@@ -519,7 +519,7 @@ def predict(self, X: np.ndarray, *, group: np.ndarray) -> np.ndarray:
return self(X, group=group)

def _check_fit_status(self, raise_error: bool = True) -> bool:
"""Checks whether this classifier has been fit on some data.
"""Check whether this classifier has been fit on some data.

Parameters
----------
@@ -546,3 +546,19 @@ def _check_fit_status(self, raise_error: bool = True) -> bool:
"This classifier has not yet been fitted to any data.")

return True

def __copy__(self):
"""Create a shallow copy of this object.
The returned copy is in a blank state, i.e., it has not been fit to any
data.
"""
return self.__class__(
predictor=self.predictor,
constraint=self.constraint,
tolerance=self.tolerance,
false_pos_cost=self.false_pos_cost,
false_neg_cost=self.false_neg_cost,
l_p_norm=self.l_p_norm,
max_roc_ticks=self.max_roc_ticks,
seed=self.seed,
)
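A short sketch of what the new `__copy__` enables, assuming the optimizer's usual `fit(X, y, group=...)` call signature (the toy data and score function are placeholders):

```python
import copy
import numpy as np
from error_parity import RelaxedThresholdOptimizer

# Toy data (illustrative only)
rng = np.random.default_rng(0)
X = rng.random((1_000, 4))
y = (X[:, 0] > 0.5).astype(int)
s = rng.integers(0, 2, size=1_000)

clf = RelaxedThresholdOptimizer(
    predictor=lambda X: X[:, 0],  # toy score function
    constraint="equalized_odds",
    l_p_norm=1,
    tolerance=0.05,
)
clf.fit(X=X, y=y, group=s)

# copy.copy dispatches to __copy__: the clone shares all hyperparameters but
# comes back unfitted, so it can be re-configured and re-fit independently.
clf_clone = copy.copy(clf)
clf_clone.tolerance = 0.10
clf_clone.fit(X=X, y=y, group=s)
```

This blank-state copy is what lets `fit_and_evaluate_postprocessing` reuse a single `postproc_template` across many tolerance values without the copies sharing fit state.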