Skip to content

Commit

Permalink
use default xgboost params if not defined in structural attacks (#277)
Browse files Browse the repository at this point in the history
* set xgb params to defaults if not defined for stability in get_unnecessary_risk()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jim-smith and pre-commit-ci[bot] authored Jun 3, 2024
1 parent a0d7876 commit beebe14
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions aisdc/attacks/structural_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_unnecessary_risk(model: BaseEstimator) -> bool:
model, (DecisionTreeClassifier, RandomForestClassifier, XGBClassifier)
):
return 0 # no experimental evidence to support rejection

unnecessary_risk = 0
max_depth = float(model.max_depth) if model.max_depth else 500

Expand Down Expand Up @@ -119,21 +120,19 @@ def get_unnecessary_risk(model: BaseEstimator) -> bool:
unnecessary_risk = 1

elif isinstance(model, XGBClassifier):
n_estimators = model.n_estimators
if n_estimators is None:
n_estimators = 1000
# checking whether params exist and using xgboost defaults if not using defaults
# from https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/sklearn.py
# and here: https://xgboost.readthedocs.io/en/stable/parameter.html
n_estimators = int(model.n_estimators) if model.n_estimators else 100
max_depth = float(model.max_depth) if model.max_depth else 6
min_child_weight = (
float(model.min_child_weight) if model.min_child_weight else 1.0
)

if (
(
max_depth > 3.5
and 3.5 < n_estimators <= 12.5
and model.min_child_weight <= 1.5
)
or (max_depth > 3.5 and n_estimators > 12.5 and model.min_child_weight <= 3)
or (
max_depth > 3.5
and n_estimators > 62.5
and 3 < model.min_child_weight <= 6
)
(max_depth > 3.5 and 3.5 < n_estimators <= 12.5 and min_child_weight <= 1.5)
or (max_depth > 3.5 and n_estimators > 12.5 and min_child_weight <= 3)
or (max_depth > 3.5 and n_estimators > 62.5 and 3 < min_child_weight <= 6)
):
unnecessary_risk = 1
return unnecessary_risk
Expand Down

0 comments on commit beebe14

Please sign in to comment.