Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
rpreen committed Jul 4, 2024
1 parent e3deebc commit 5094bf0
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions aisdc/attacks/structural_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# pylint: disable=too-many-boolean-expressions,chained-comparison
# pylint: disable=chained-comparison


def get_unnecessary_risk(model: BaseEstimator) -> bool:
Expand Down Expand Up @@ -73,7 +73,7 @@ def _get_unnecessary_risk_dt(model: DecisionTreeClassifier) -> bool:
min_samples_leaf = model.min_samples_leaf
min_samples_split = model.min_samples_split
splitter = model.splitter
if (
return (
(max_depth > 7.5 and min_samples_leaf <= 7.5 and min_samples_split <= 15)
or (
splitter == "best"
Expand All @@ -99,9 +99,7 @@ def _get_unnecessary_risk_dt(model: DecisionTreeClassifier) -> bool:
and min_samples_leaf <= 7.5
and max_features is None
)
):
return True
return False
)


def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool:
Expand All @@ -111,7 +109,7 @@ def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool:
max_features = model.max_features
min_samples_leaf = model.min_samples_leaf
min_samples_split = model.min_samples_split
if (
return (
(max_depth > 3.5 and n_estimators > 35 and max_features is not None)
or (
max_depth > 3.5
Expand All @@ -126,9 +124,7 @@ def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool:
and min_samples_leaf <= 15
and not model.bootstrap
)
):
return True
return False
)


def _get_unnecessary_risk_xgb(model: XGBClassifier) -> bool:
Expand All @@ -141,13 +137,11 @@ def _get_unnecessary_risk_xgb(model: XGBClassifier) -> bool:
n_estimators = int(model.n_estimators) if model.n_estimators else 100
max_depth = float(model.max_depth) if model.max_depth else 6
min_child_weight = float(model.min_child_weight) if model.min_child_weight else 1.0
if (
return (
(max_depth > 3.5 and 3.5 < n_estimators <= 12.5 and min_child_weight <= 1.5)
or (max_depth > 3.5 and n_estimators > 12.5 and min_child_weight <= 3)
or (max_depth > 3.5 and n_estimators > 62.5 and 3 < min_child_weight <= 6)
):
return True
return False
)


def get_tree_parameter_count(dtree: DecisionTreeClassifier) -> int:
Expand Down

0 comments on commit 5094bf0

Please sign in to comment.