diff --git a/aisdc/attacks/structural_attack.py b/aisdc/attacks/structural_attack.py index 76f32453..62bc08a5 100644 --- a/aisdc/attacks/structural_attack.py +++ b/aisdc/attacks/structural_attack.py @@ -25,7 +25,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# pylint: disable=too-many-boolean-expressions,chained-comparison +# pylint: disable=chained-comparison def get_unnecessary_risk(model: BaseEstimator) -> bool: @@ -73,7 +73,7 @@ def _get_unnecessary_risk_dt(model: DecisionTreeClassifier) -> bool: min_samples_leaf = model.min_samples_leaf min_samples_split = model.min_samples_split splitter = model.splitter - if ( + return ( (max_depth > 7.5 and min_samples_leaf <= 7.5 and min_samples_split <= 15) or ( splitter == "best" @@ -99,9 +99,7 @@ def _get_unnecessary_risk_dt(model: DecisionTreeClassifier) -> bool: and min_samples_leaf <= 7.5 and max_features is None ) - ): - return True - return False + ) def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool: @@ -111,7 +109,7 @@ def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool: max_features = model.max_features min_samples_leaf = model.min_samples_leaf min_samples_split = model.min_samples_split - if ( + return ( (max_depth > 3.5 and n_estimators > 35 and max_features is not None) or ( max_depth > 3.5 @@ -126,9 +124,7 @@ def _get_unnecessary_risk_rf(model: RandomForestClassifier) -> bool: and min_samples_leaf <= 15 and not model.bootstrap ) - ): - return True - return False + ) def _get_unnecessary_risk_xgb(model: XGBClassifier) -> bool: @@ -141,13 +137,11 @@ def _get_unnecessary_risk_xgb(model: XGBClassifier) -> bool: n_estimators = int(model.n_estimators) if model.n_estimators else 100 max_depth = float(model.max_depth) if model.max_depth else 6 min_child_weight = float(model.min_child_weight) if model.min_child_weight else 1.0 - if ( + return ( (max_depth > 3.5 and 3.5 < n_estimators <= 12.5 and min_child_weight <= 1.5) or (max_depth > 3.5 and n_estimators > 12.5 and min_child_weight <= 3) or (max_depth > 3.5 and n_estimators > 62.5 and 3 < min_child_weight <= 6) - ): - return True - return False + ) def get_tree_parameter_count(dtree: DecisionTreeClassifier) -> int: