Skip to content

Commit

Permalink
update lira test
Browse files Browse the repository at this point in the history
  • Loading branch information
rpreen committed Jul 8, 2024
1 parent cd72d64 commit 37be9c1
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
2 changes: 1 addition & 1 deletion aisdc/attacks/likelihood_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__( # pylint: disable=too-many-arguments
self.result["out_prob"] = []
self.result["out_mean"] = []
self.result["out_std"] = []
if self.mode == "online_carlini":
if self.mode == "online-carlini":
self.result["in_prob"] = []
self.result["in_mean"] = []
self.result["in_std"] = []
Expand Down
63 changes: 58 additions & 5 deletions tests/attacks/test_lira_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@pytest.fixture(name="dummy_classifier_setup")
def fixture_dummy_classifier_setup():
"""Set up common things for DummyClassifier."""
dummy = LIRAAttack._DummyClassifier()
dummy = LIRAAttack._DummyClassifier() # pylint: disable=protected-access
X = np.array([[0.2, 0.8], [0.7, 0.3]])
return dummy, X

Expand Down Expand Up @@ -50,8 +50,61 @@ def fixture_lira_classifier_setup():
return target


def test_lira_attack(lira_classifier_setup):
"""Test LiRA attack."""
@pytest.mark.parametrize(
("mode", "expect_error"),
[
("offline", False),
("offline-carlini", False),
("online-carlini", False),
("blah", True),
],
)
@pytest.mark.parametrize("fix_variance", [True, False])
def test_lira_attack(lira_classifier_setup, mode, expect_error, fix_variance):
"""Test LiRA attack with different modes."""
# create target
target = lira_classifier_setup
lira = LIRAAttack(n_shadow_models=20, output_dir="test_output_lira")
lira.attack(target)
# create attack
lira = LIRAAttack(
output_dir="test_output_lira",
write_report=True,
n_shadow_models=20,
p_thresh=0.05,
mode=mode,
fix_variance=fix_variance,
report_individual=True,
)

# check unsupported modes raise an error
if expect_error:
with pytest.raises(ValueError, match="Unsupported LiRA mode.*"):
output = lira.attack(target)
return

# run LiRA
output = lira.attack(target)

# check metadata
metadata = output["metadata"]
assert metadata["attack_name"] == "LiRA Attack"
assert metadata["attack_params"]["n_shadow_models"] == 20
assert metadata["attack_params"]["p_thresh"] == 0.05
assert metadata["attack_params"]["mode"] == mode
assert metadata["attack_params"]["fix_variance"] == fix_variance
assert metadata["attack_params"]["report_individual"]

# check global metrics
global_metrics = metadata["global_metrics"]
sig = "Significant at p=0.05"
not_sig = "Not significant at p=0.05"

if mode == "offline-carlini" and fix_variance:
assert global_metrics["PDIF_sig"] == not_sig
else:
assert global_metrics["PDIF_sig"] == sig
assert global_metrics["AUC_sig"] == sig

# check instance metrics
metrics = output["attack_experiment_logger"]["attack_instance_logger"]["instance_0"]
assert 0 <= metrics["TPR"] <= 1
assert 0 <= metrics["FPR"] <= 1

0 comments on commit 37be9c1

Please sign in to comment.