diff --git a/development_files/structural_attack.ipynb b/development_files/structural_attack.ipynb index 14c988c8..f64020d0 100644 --- a/development_files/structural_attack.ipynb +++ b/development_files/structural_attack.ipynb @@ -1,957 +1,957 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "5b5bd89b-c0f9-476a-80a2-79ad044e11d2", - "metadata": {}, - "source": [ - "# Notebook for developing code to go into structural_attack class" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "dd7f7614-cbac-43a5-bf90-a59712eca953", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "\n", - "\n", - "# for development use local copy of aisdc in preference to installed version\n", - "sys.path.insert(0, os.path.abspath(\"..\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "711cbd17-2e8e-452c-b9be-0b662579e333", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "import numpy as np\n", - "from sklearn.datasets import load_breast_cancer\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.tree import DecisionTreeClassifier\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "from xgboost import XGBClassifier\n", - "from sklearn.svm import SVC\n", - "\n", - "\n", - "from aisdc.attacks.structural_attack import (\n", - " StructuralAttack,\n", - ") # pylint: disable = import-error\n", - "from aisdc.attacks.target import Target # pylint: disable = import-error" - ] - }, - { - "cell_type": "markdown", - "id": "536bf3bd-b5cc-4c8e-abed-bcd6dcfdf96e", - "metadata": {}, - "source": [ - "## helper function for test" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "b761fd2e-4a96-49f3-bbc4-888e26382c15", - "metadata": {}, - "outputs": [], - "source": [ - "def get_target(modeltype: str, **kwargs) -> Target:\n", - " \"\"\"loads dataset and creates target of the desired type\"\"\"\n", - "\n", - " X, y = load_breast_cancer(return_X_y=True, as_frame=False)\n", - " train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3)\n", - "\n", - " # these types should be handled\n", - " if modeltype == \"dt\":\n", - " target_model = DecisionTreeClassifier(**kwargs)\n", - " elif modeltype == \"rf\":\n", - " target_model = RandomForestClassifier(**kwargs)\n", - " elif modeltype == \"xgb\":\n", - " target_model = XGBClassifier(**kwargs)\n", - " # should get polite error but not DoF yet\n", - " elif modeltype == \"svc\":\n", - " target_model = SVC(**kwargs)\n", - " else:\n", - " raise NotImplementedError(\"model type passed to get_model unknown\")\n", - "\n", - " # Train the classifier\n", - " target_model.fit(train_X, train_y)\n", - "\n", - " # Wrap the model and data in a Target object\n", - " target = Target(model=target_model)\n", - " target.add_processed_data(train_X, train_y, test_X, test_y)\n", - "\n", - " return target" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c4821ec0-f718-45f3-8912-3fcf69056e4e", - "metadata": {}, - "outputs": [], - "source": [ - "import importlib\n", - "import aisdc.attacks.structural_attack\n", - "\n", - "importlib.reload(aisdc.attacks.structural_attack)\n", - "from aisdc.attacks.structural_attack import StructuralAttack" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "8413ddf8-730f-4bf6-8299-7e01dc4806e3", - "metadata": {}, - "outputs": [], - "source": [ - "def test_dt():\n", - " \"\"\"test for decision tree classifier\"\"\"\n", - "\n", - " print(\"\\n\\n\\n====== Non Disclosive ====\\n\\n\")\n", - "\n", - " param_dict = {\"max_depth\": 1, \"min_samples_leaf\": 150}\n", - " target = get_target(\"dt\", **param_dict)\n", - " target_path = target.save(\"dt.sav\")\n", - " myattack = StructuralAttack(target_path=\"dt.sav\")\n", - " myattack.attack(target)\n", - " # assert myattack.DoF_risk ==0 ,\"should be no DoF risk with devision stump\"\n", - " # assert myattack.k_anonymity_risk ==0, 'should be no k-anonymity risk with min_samples_leaf 150'\n", - " # assert myattack.class_disclosure_risk ==0,'no class disclsoure risk for stump with min samles leaf 150'\n", - " # assert myattack.unnecessary_risk ==0, 'not unnecessary risk if max_depth < 3.5'\n", - " print(\n", - " f\"equiv_classes is {myattack.equiv_classes}\\n\"\n", - " f\"equiv_counts is {myattack.equiv_counts}\\n\"\n", - " f\"equiv_members is {myattack.equiv_members}\\n\"\n", - " )\n", - "\n", - " print(\"\\n\\n\\n====== Now Disclosive ====\\n\\n\")\n", - " # highly disclosive\n", - " param_dict = {\"max_depth\": None, \"min_samples_leaf\": 5, \"min_samples_split\": 2}\n", - " target2 = get_target(\"dt\", **param_dict)\n", - " myattack2 = StructuralAttack()\n", - " myattack2.attack(target2)\n", - " # assert myattack2.DoF_risk ==0 ,\"should be no DoF risk with decision stump\"\n", - " # assert myattack2.k_anonymity_risk ==1, 'should be k-anonymity risk with unlimited depth and min_samples_leaf 5'\n", - " # assert myattack2.class_disclosure_risk ==1,'should be class disclosure risk with unlimited depth and min_samples_leaf 5'\n", - " # assert myattack2.unnecessary_risk ==1, ' unnecessary risk with unlimited depth and min_samples_leaf 5'\n", - " # print(f' attack._get_param_names returns {myattack2._get_param_names()}')\n", - " # print(f' attack.get_params returns {myattack2.get_params()}')\n", - "\n", - " print(\n", - " f\"equiv_classes is {myattack2.equiv_classes}\\n\"\n", - " f\"equiv_counts is {myattack2.equiv_counts}\\n\"\n", - " f\"equiv_members is {myattack2.equiv_members}\\n\"\n", - " )\n", - "\n", - " # myattack.make_report()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "215781af-a74d-4300-b572-1a9f696457b8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:acro:version: 0.4.2\n", - "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", - "INFO:acro:automatic suppression: False\n", - "INFO:structural_attack:Thresholds for count 10 and DoF 10\n", - "INFO:acro:version: 0.4.2\n", - "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", - "INFO:acro:automatic suppression: False\n", - "INFO:structural_attack:Thresholds for count 10 and DoF 10\n" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "5b5bd89b-c0f9-476a-80a2-79ad044e11d2", + "metadata": {}, + "source": [ + "# Notebook for developing code to go into structural_attack class" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dd7f7614-cbac-43a5-bf90-a59712eca953", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "\n", + "\n", + "# for development use local copy of aisdc in preference to installed version\n", + "sys.path.insert(0, os.path.abspath(\"..\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "711cbd17-2e8e-452c-b9be-0b662579e333", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "import numpy as np\n", + "from sklearn.datasets import load_breast_cancer\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from xgboost import XGBClassifier\n", + "from sklearn.svm import SVC\n", + "\n", + "\n", + "from aisdc.attacks.structural_attack import (\n", + " StructuralAttack,\n", + ") # pylint: disable = import-error\n", + "from aisdc.attacks.target import Target # pylint: disable = import-error" + ] + }, + { + "cell_type": "markdown", + "id": "536bf3bd-b5cc-4c8e-abed-bcd6dcfdf96e", + "metadata": {}, + "source": [ + "## helper function for test" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b761fd2e-4a96-49f3-bbc4-888e26382c15", + "metadata": {}, + "outputs": [], + "source": [ + "def get_target(modeltype: str, **kwargs) -> Target:\n", + " \"\"\"loads dataset and creates target of the desired type\"\"\"\n", + "\n", + " X, y = load_breast_cancer(return_X_y=True, as_frame=False)\n", + " train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3)\n", + "\n", + " # these types should be handled\n", + " if modeltype == \"dt\":\n", + " target_model = DecisionTreeClassifier(**kwargs)\n", + " elif modeltype == \"rf\":\n", + " target_model = RandomForestClassifier(**kwargs)\n", + " elif modeltype == \"xgb\":\n", + " target_model = XGBClassifier(**kwargs)\n", + " # should get polite error but not DoF yet\n", + " elif modeltype == \"svc\":\n", + " target_model = SVC(**kwargs)\n", + " else:\n", + " raise NotImplementedError(\"model type passed to get_model unknown\")\n", + "\n", + " # Train the classifier\n", + " target_model.fit(train_X, train_y)\n", + "\n", + " # Wrap the model and data in a Target object\n", + " target = Target(model=target_model)\n", + " target.add_processed_data(train_X, train_y, test_X, test_y)\n", + "\n", + " return target" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c4821ec0-f718-45f3-8912-3fcf69056e4e", + "metadata": {}, + "outputs": [], + "source": [ + "import importlib\n", + "import aisdc.attacks.structural_attack\n", + "\n", + "importlib.reload(aisdc.attacks.structural_attack)\n", + "from aisdc.attacks.structural_attack import StructuralAttack" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8413ddf8-730f-4bf6-8299-7e01dc4806e3", + "metadata": {}, + "outputs": [], + "source": [ + "def test_dt():\n", + " \"\"\"test for decision tree classifier\"\"\"\n", + "\n", + " print(\"\\n\\n\\n====== Non Disclosive ====\\n\\n\")\n", + "\n", + " param_dict = {\"max_depth\": 1, \"min_samples_leaf\": 150}\n", + " target = get_target(\"dt\", **param_dict)\n", + " target_path = target.save(\"dt.sav\")\n", + " myattack = StructuralAttack(target_path=\"dt.sav\")\n", + " myattack.attack(target)\n", + " # assert myattack.DoF_risk ==0 ,\"should be no DoF risk with devision stump\"\n", + " # assert myattack.k_anonymity_risk ==0, 'should be no k-anonymity risk with min_samples_leaf 150'\n", + " # assert myattack.class_disclosure_risk ==0,'no class disclsoure risk for stump with min samles leaf 150'\n", + " # assert myattack.unnecessary_risk ==0, 'not unnecessary risk if max_depth < 3.5'\n", + " print(\n", + " f\"equiv_classes is {myattack.equiv_classes}\\n\"\n", + " f\"equiv_counts is {myattack.equiv_counts}\\n\"\n", + " f\"equiv_members is {myattack.equiv_members}\\n\"\n", + " )\n", + "\n", + " print(\"\\n\\n\\n====== Now Disclosive ====\\n\\n\")\n", + " # highly disclosive\n", + " param_dict = {\"max_depth\": None, \"min_samples_leaf\": 5, \"min_samples_split\": 2}\n", + " target2 = get_target(\"dt\", **param_dict)\n", + " myattack2 = StructuralAttack()\n", + " myattack2.attack(target2)\n", + " # assert myattack2.DoF_risk ==0 ,\"should be no DoF risk with decision stump\"\n", + " # assert myattack2.k_anonymity_risk ==1, 'should be k-anonymity risk with unlimited depth and min_samples_leaf 5'\n", + " # assert myattack2.class_disclosure_risk ==1,'should be class disclosure risk with unlimited depth and min_samples_leaf 5'\n", + " # assert myattack2.unnecessary_risk ==1, ' unnecessary risk with unlimited depth and min_samples_leaf 5'\n", + " # print(f' attack._get_param_names returns {myattack2._get_param_names()}')\n", + " # print(f' attack.get_params returns {myattack2.get_params()}')\n", + "\n", + " print(\n", + " f\"equiv_classes is {myattack2.equiv_classes}\\n\"\n", + " f\"equiv_counts is {myattack2.equiv_counts}\\n\"\n", + " f\"equiv_members is {myattack2.equiv_members}\\n\"\n", + " )\n", + "\n", + " # myattack.make_report()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "215781af-a74d-4300-b572-1a9f696457b8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:acro:version: 0.4.2\n", + "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", + "INFO:acro:automatic suppression: False\n", + "INFO:structural_attack:Thresholds for count 10 and DoF 10\n", + "INFO:acro:version: 0.4.2\n", + "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", + "INFO:acro:automatic suppression: False\n", + "INFO:structural_attack:Thresholds for count 10 and DoF 10\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "====== Non Disclosive ====\n", + "\n", + "\n", + "ingroup [ 0 2 5 7 9 10 11 12 14 15 17 19 20 22 24 26 28 29\n", + " 30 32 34 35 36 37 38 39 40 41 43 44 46 47 48 50 51 52\n", + " 53 54 57 58 61 63 64 65 66 69 71 72 73 76 78 81 82 85\n", + " 86 87 88 89 92 93 94 97 98 102 103 105 106 108 109 110 113 115\n", + " 116 117 118 121 123 125 126 128 130 131 132 134 135 137 138 139 140 141\n", + " 142 143 145 146 147 148 149 153 154 156 158 162 163 167 169 170 172 173\n", + " 176 178 180 181 182 183 184 186 187 188 192 193 195 196 197 198 199 201\n", + " 202 203 206 207 209 211 213 215 218 219 220 222 223 224 225 226 228 229\n", + " 231 233 235 237 238 240 241 242 245 246 247 248 250 252 254 256 258 259\n", + " 261 262 264 265 269 272 273 274 275 276 277 278 279 282 284 285 286 288\n", + " 289 294 299 300 303 304 307 308 309 311 312 314 315 316 317 318 319 321\n", + " 323 324 328 329 330 332 334 335 336 340 342 344 346 347 351 352 354 355\n", + " 357 358 360 362 365 366 367 371 373 374 375 377 379 380 381 384 386 388\n", + " 389 392 394 395 396 397],count 240\n", + "ingroup [ 1 3 4 6 8 13 16 18 21 23 25 27 31 33 42 45 49 55\n", + " 56 59 60 62 67 68 70 74 75 77 79 80 83 84 90 91 95 96\n", + " 99 100 101 104 107 111 112 114 119 120 122 124 127 129 133 136 144 150\n", + " 151 152 155 157 159 160 161 164 165 166 168 171 174 175 177 179 185 189\n", + " 190 191 194 200 204 205 208 210 212 214 216 217 221 227 230 232 234 236\n", + " 239 243 244 249 251 253 255 257 260 263 266 267 268 270 271 280 281 283\n", + " 287 290 291 292 293 295 296 297 298 301 302 305 306 310 313 320 322 325\n", + " 326 327 331 333 337 338 339 341 343 345 348 349 350 353 356 359 361 363\n", + " 364 368 369 370 372 376 378 382 383 385 387 390 391 393],count 158\n", + "equiv_classes is [1 2]\n", + "equiv_counts is [240 158]\n", + "equiv_members is [array([ 0, 2, 5, 7, 9, 10, 11, 12, 14, 15, 17, 19, 20,\n", + " 22, 24, 26, 28, 29, 30, 32, 34, 35, 36, 37, 38, 39,\n", + " 40, 41, 43, 44, 46, 47, 48, 50, 51, 52, 53, 54, 57,\n", + " 58, 61, 63, 64, 65, 66, 69, 71, 72, 73, 76, 78, 81,\n", + " 82, 85, 86, 87, 88, 89, 92, 93, 94, 97, 98, 102, 103,\n", + " 105, 106, 108, 109, 110, 113, 115, 116, 117, 118, 121, 123, 125,\n", + " 126, 128, 130, 131, 132, 134, 135, 137, 138, 139, 140, 141, 142,\n", + " 143, 145, 146, 147, 148, 149, 153, 154, 156, 158, 162, 163, 167,\n", + " 169, 170, 172, 173, 176, 178, 180, 181, 182, 183, 184, 186, 187,\n", + " 188, 192, 193, 195, 196, 197, 198, 199, 201, 202, 203, 206, 207,\n", + " 209, 211, 213, 215, 218, 219, 220, 222, 223, 224, 225, 226, 228,\n", + " 229, 231, 233, 235, 237, 238, 240, 241, 242, 245, 246, 247, 248,\n", + " 250, 252, 254, 256, 258, 259, 261, 262, 264, 265, 269, 272, 273,\n", + " 274, 275, 276, 277, 278, 279, 282, 284, 285, 286, 288, 289, 294,\n", + " 299, 300, 303, 304, 307, 308, 309, 311, 312, 314, 315, 316, 317,\n", + " 318, 319, 321, 323, 324, 328, 329, 330, 332, 334, 335, 336, 340,\n", + " 342, 344, 346, 347, 351, 352, 354, 355, 357, 358, 360, 362, 365,\n", + " 366, 367, 371, 373, 374, 375, 377, 379, 380, 381, 384, 386, 388,\n", + " 389, 392, 394, 395, 396, 397]), array([ 1, 3, 4, 6, 8, 13, 16, 18, 21, 23, 25, 27, 31,\n", + " 33, 42, 45, 49, 55, 56, 59, 60, 62, 67, 68, 70, 74,\n", + " 75, 77, 79, 80, 83, 84, 90, 91, 95, 96, 99, 100, 101,\n", + " 104, 107, 111, 112, 114, 119, 120, 122, 124, 127, 129, 133, 136,\n", + " 144, 150, 151, 152, 155, 157, 159, 160, 161, 164, 165, 166, 168,\n", + " 171, 174, 175, 177, 179, 185, 189, 190, 191, 194, 200, 204, 205,\n", + " 208, 210, 212, 214, 216, 217, 221, 227, 230, 232, 234, 236, 239,\n", + " 243, 244, 249, 251, 253, 255, 257, 260, 263, 266, 267, 268, 270,\n", + " 271, 280, 281, 283, 287, 290, 291, 292, 293, 295, 296, 297, 298,\n", + " 301, 302, 305, 306, 310, 313, 320, 322, 325, 326, 327, 331, 333,\n", + " 337, 338, 339, 341, 343, 345, 348, 349, 350, 353, 356, 359, 361,\n", + " 363, 364, 368, 369, 370, 372, 376, 378, 382, 383, 385, 387, 390,\n", + " 391, 393])]\n", + "\n", + "\n", + "\n", + "\n", + "====== Now Disclosive ====\n", + "\n", + "\n", + "ingroup [ 34 89 95 100 157 252 282 296],count 8\n", + "ingroup [ 29 42 183 337 393],count 5\n", + "ingroup [ 1 2 3 4 6 7 9 10 12 13 16 20 21 22 27 28 35 36\n", + " 37 39 40 46 48 49 50 51 52 55 57 58 63 64 66 67 69 70\n", + " 72 76 77 79 80 81 85 93 97 102 104 105 106 110 111 112 114 115\n", + " 116 118 119 121 123 125 128 130 131 135 137 138 139 141 142 146 150 152\n", + " 155 156 158 159 160 162 164 165 167 168 169 170 171 172 175 176 178 180\n", + " 182 184 186 189 192 195 198 199 200 201 203 204 205 208 210 212 214 215\n", + " 216 219 220 221 222 225 226 228 230 235 238 240 241 244 246 247 250 251\n", + " 253 254 256 259 261 263 264 265 267 268 269 270 272 273 274 276 278 279\n", + " 281 284 287 288 290 291 293 295 297 298 299 300 301 303 304 311 312 317\n", + " 318 319 320 321 322 323 324 325 326 327 331 332 334 342 343 345 346 347\n", + " 348 349 353 356 359 360 362 364 365 369 370 373 375 376 377 378 380 382\n", + " 384 385 387 390 392 396 397],count 205\n", + "ingroup [133 193 285 330 352],count 5\n", + "ingroup [ 31 53 68 122 394],count 5\n", + "ingroup [ 60 103 237 262 266 391],count 6\n", + "ingroup [ 88 124 147 174 207 379],count 6\n", + "ingroup [ 75 113 140 179 181 232 307 361],count 8\n", + "ingroup [ 18 47 191 234 351],count 5\n", + "ingroup [ 11 117 134 149 242 271 277 329 338 371 395],count 11\n", + "ingroup [145 248 302 339 374],count 5\n", + "ingroup [ 19 255 341 366 368],count 5\n", + "ingroup [ 0 5 8 14 15 17 23 24 25 26 30 32 33 38 41 43 44 45\n", + " 54 56 59 61 62 65 71 73 74 78 82 83 84 86 87 90 91 92\n", + " 94 96 98 99 101 107 108 109 120 126 127 129 132 136 143 144 148 151\n", + " 153 154 161 163 166 173 177 185 187 188 190 194 196 197 202 206 209 211\n", + " 213 217 218 223 224 227 229 231 233 236 239 243 245 249 257 258 260 275\n", + " 280 283 286 289 292 294 305 306 308 309 310 313 314 315 316 328 333 335\n", + " 336 340 344 350 354 355 357 358 363 367 372 381 383 386 388 389],count 124\n", + "equiv_classes is [ 5 6 7 10 11 12 15 16 18 19 22 23 24]\n", + "equiv_counts is [ 8 5 205 5 5 6 6 8 5 11 5 5 124]\n", + "equiv_members is [array([ 34, 89, 95, 100, 157, 252, 282, 296]), array([ 29, 42, 183, 337, 393]), array([ 1, 2, 3, 4, 6, 7, 9, 10, 12, 13, 16, 20, 21,\n", + " 22, 27, 28, 35, 36, 37, 39, 40, 46, 48, 49, 50, 51,\n", + " 52, 55, 57, 58, 63, 64, 66, 67, 69, 70, 72, 76, 77,\n", + " 79, 80, 81, 85, 93, 97, 102, 104, 105, 106, 110, 111, 112,\n", + " 114, 115, 116, 118, 119, 121, 123, 125, 128, 130, 131, 135, 137,\n", + " 138, 139, 141, 142, 146, 150, 152, 155, 156, 158, 159, 160, 162,\n", + " 164, 165, 167, 168, 169, 170, 171, 172, 175, 176, 178, 180, 182,\n", + " 184, 186, 189, 192, 195, 198, 199, 200, 201, 203, 204, 205, 208,\n", + " 210, 212, 214, 215, 216, 219, 220, 221, 222, 225, 226, 228, 230,\n", + " 235, 238, 240, 241, 244, 246, 247, 250, 251, 253, 254, 256, 259,\n", + " 261, 263, 264, 265, 267, 268, 269, 270, 272, 273, 274, 276, 278,\n", + " 279, 281, 284, 287, 288, 290, 291, 293, 295, 297, 298, 299, 300,\n", + " 301, 303, 304, 311, 312, 317, 318, 319, 320, 321, 322, 323, 324,\n", + " 325, 326, 327, 331, 332, 334, 342, 343, 345, 346, 347, 348, 349,\n", + " 353, 356, 359, 360, 362, 364, 365, 369, 370, 373, 375, 376, 377,\n", + " 378, 380, 382, 384, 385, 387, 390, 392, 396, 397]), array([133, 193, 285, 330, 352]), array([ 31, 53, 68, 122, 394]), array([ 60, 103, 237, 262, 266, 391]), array([ 88, 124, 147, 174, 207, 379]), array([ 75, 113, 140, 179, 181, 232, 307, 361]), array([ 18, 47, 191, 234, 351]), array([ 11, 117, 134, 149, 242, 271, 277, 329, 338, 371, 395]), array([145, 248, 302, 339, 374]), array([ 19, 255, 341, 366, 368]), array([ 0, 5, 8, 14, 15, 17, 23, 24, 25, 26, 30, 32, 33,\n", + " 38, 41, 43, 44, 45, 54, 56, 59, 61, 62, 65, 71, 73,\n", + " 74, 78, 82, 83, 84, 86, 87, 90, 91, 92, 94, 96, 98,\n", + " 99, 101, 107, 108, 109, 120, 126, 127, 129, 132, 136, 143, 144,\n", + " 148, 151, 153, 154, 161, 163, 166, 173, 177, 185, 187, 188, 190,\n", + " 194, 196, 197, 202, 206, 209, 211, 213, 217, 218, 223, 224, 227,\n", + " 229, 231, 233, 236, 239, 243, 245, 249, 257, 258, 260, 275, 280,\n", + " 283, 286, 289, 292, 294, 305, 306, 308, 309, 310, 313, 314, 315,\n", + " 316, 328, 333, 335, 336, 340, 344, 350, 354, 355, 357, 358, 363,\n", + " 367, 372, 381, 383, 386, 388, 389])]\n", + "\n" + ] + } + ], + "source": [ + "test_dt()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "deb9d05a-24aa-464c-9a7b-0647ae1a56e1", + "metadata": {}, + "outputs": [], + "source": [ + "def test_rf():\n", + " \"\"\"test for decision tree classifier\"\"\"\n", + "\n", + " print(\"\\n\\n\\n====== Non Disclosive ====\\n\\n\")\n", + "\n", + " param_dict = {\"max_depth\": 1, \"min_samples_leaf\": 150, \"n_estimators\": 5}\n", + " target = get_target(\"rf\", **param_dict)\n", + " target_path = target.save(\"dt.sav\")\n", + " myattack = StructuralAttack(target_path=\"dt.sav\")\n", + " myattack.attack(target)\n", + " # assert myattack.DoF_risk ==0 ,\"should be no DoF risk with devision stump\"\n", + " # assert myattack.k_anonymity_risk ==0, 'should be no k-anonymity risk with min_samples_leaf 150'\n", + " # assert myattack.class_disclosure_risk ==0,'no class disclsoure risk for stump with min samles leaf 150'\n", + " # assert myattack.unnecessary_risk ==0, 'not unnecessary risk if max_depth < 3.5'\n", + " print(\n", + " f\" {len(myattack.equiv_classes)} equiv_classes:\\n{myattack.equiv_classes}\\n\"\n", + " f\"equiv_counts is {myattack.equiv_counts}\\n\"\n", + " # f'equiv_members is {myattack.equiv_members}\\n'\n", + " )\n", + " for i in range(len(myattack.equiv_members)):\n", + " print(\n", + " f\" {len(myattack.equiv_members[i])} members for group {i}\\n\"\n", + " f\"{myattack.equiv_members[i]}\"\n", + " )\n", + "\n", + " print(\"\\n\\n\\n====== Now Disclosive ====\\n\\n\")\n", + " # highly disclosive\n", + " param_dict = {\n", + " \"max_depth\": None,\n", + " \"min_samples_leaf\": 5,\n", + " \"min_samples_split\": 2,\n", + " \"n_estimators\": 5,\n", + " }\n", + " target2 = get_target(\"rf\", **param_dict)\n", + " myattack2 = StructuralAttack()\n", + " myattack2.attack(target2)\n", + " # assert myattack2.DoF_risk ==0 ,\"should be no DoF risk with decision stump\"\n", + " # assert myattack2.k_anonymity_risk ==1, 'should be k-anonymity risk with unlimited depth and min_samples_leaf 5'\n", + " # assert myattack2.class_disclosure_risk ==1,'should be class disclosure risk with unlimited depth and min_samples_leaf 5'\n", + " # assert myattack2.unnecessary_risk ==1, ' unnecessary risk with unlimited depth and min_samples_leaf 5'\n", + " print(f\" attack._get_param_names returns {myattack2._get_param_names()}\")\n", + " print(f\" attack.get_params returns {myattack2.get_params()}\")\n", + "\n", + " print(\n", + " f\" {len(myattack2.equiv_classes)} equiv_classes:\\n{myattack2.equiv_classes}\\n\"\n", + " f\"equiv_counts is {myattack2.equiv_counts}\\n\"\n", + " # f'equiv_members is {myattack2.equiv_members}\\n'\n", + " )\n", + " for i in range(len(myattack2.equiv_members)):\n", + " print(\n", + " f\" {len(myattack2.equiv_members[i])} members for group {i}\\n\"\n", + " f\"{myattack2.equiv_members[i]}\"\n", + " )\n", + "\n", + " # myattack.make_report()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b962b63e-4b6b-47e0-b718-7d0e649dabc5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:acro:version: 0.4.2\n", + "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", + "INFO:acro:automatic suppression: False\n", + "INFO:structural_attack:Thresholds for count 10 and DoF 10\n", + "INFO:acro:version: 0.4.2\n", + "INFO:acro:config: {'safe_threshold': 10, 'safe_dof_threshold': 10, 'safe_nk_n': 2, 'safe_nk_k': 0.9, 'safe_pratio_p': 0.1, 'check_missing_values': False}\n", + "INFO:acro:automatic suppression: False\n", + "INFO:structural_attack:Thresholds for count 10 and DoF 10\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "====== Non Disclosive ====\n", + "\n", + "\n", + " 1 equiv_classes:\n", + "[[0.33919598 0.66080402]]\n", + "equiv_counts is [398]\n", + "\n", + " 398 members for group 0\n", + "[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17\n", + " 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35\n", + " 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53\n", + " 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71\n", + " 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89\n", + " 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107\n", + " 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125\n", + " 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143\n", + " 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161\n", + " 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179\n", + " 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197\n", + " 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215\n", + " 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233\n", + " 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251\n", + " 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269\n", + " 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287\n", + " 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305\n", + " 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323\n", + " 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341\n", + " 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359\n", + " 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377\n", + " 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395\n", + " 396 397]\n", + "\n", + "\n", + "\n", + "====== Now Disclosive ====\n", + "\n", + "\n", + " attack._get_param_names returns ['risk_appetite_config', 'target_path', 'output_dir', 'report_name']\n", + " attack.get_params returns {'risk_appetite_config': 'default', 'target_path': None, 'output_dir': 'outputs_structural', 'report_name': 'report_structural'}\n", + " 88 equiv_classes:\n", + "[[0. 1. ]\n", + " [0.01818182 0.98181818]\n", + " [0.02 0.98 ]\n", + " [0.025 0.975 ]\n", + " [0.03333333 0.96666667]\n", + " [0.04040404 0.95959596]\n", + " [0.04444444 0.95555556]\n", + " [0.05333333 0.94666667]\n", + " [0.05714286 0.94285714]\n", + " [0.06 0.94 ]\n", + " [0.06444444 0.93555556]\n", + " [0.06666667 0.93333333]\n", + " [0.08 0.92 ]\n", + " [0.08222222 0.91777778]\n", + " [0.08666667 0.91333333]\n", + " [0.1 0.9 ]\n", + " [0.11428571 0.88571429]\n", + " [0.12 0.88 ]\n", + " [0.13333333 0.86666667]\n", + " [0.13611111 0.86388889]\n", + " [0.15555556 0.84444444]\n", + " [0.16444444 0.83555556]\n", + " [0.16666667 0.83333333]\n", + " [0.16666667 0.83333333]\n", + " [0.18222222 0.81777778]\n", + " [0.19428571 0.80571429]\n", + " [0.24040404 0.75959596]\n", + " [0.25555556 0.74444444]\n", + " [0.25714286 0.74285714]\n", + " [0.26984127 0.73015873]\n", + " [0.27111111 0.72888889]\n", + " [0.29047619 0.70952381]\n", + " [0.30277778 0.69722222]\n", + " [0.30707071 0.69292929]\n", + " [0.36103896 0.63896104]\n", + " [0.38608059 0.61391941]\n", + " [0.38698413 0.61301587]\n", + " [0.44761905 0.55238095]\n", + " [0.44888889 0.55111111]\n", + " [0.47777778 0.52222222]\n", + " [0.48770563 0.51229437]\n", + " [0.52031746 0.47968254]\n", + " [0.54761905 0.45238095]\n", + " [0.57818182 0.42181818]\n", + " [0.57936508 0.42063492]\n", + " [0.5847619 0.4152381 ]\n", + " [0.58888889 0.41111111]\n", + " [0.59261905 0.40738095]\n", + " [0.5956044 0.4043956 ]\n", + " [0.62 0.38 ]\n", + " [0.62666667 0.37333333]\n", + " [0.63142857 0.36857143]\n", + " [0.64322344 0.35677656]\n", + " [0.64484848 0.35515152]\n", + " [0.65555556 0.34444444]\n", + " [0.65655678 0.34344322]\n", + " [0.70989011 0.29010989]\n", + " [0.72380952 0.27619048]\n", + " [0.72666667 0.27333333]\n", + " [0.7556044 0.2443956 ]\n", + " [0.76103896 0.23896104]\n", + " [0.78888889 0.21111111]\n", + " [0.8 0.2 ]\n", + " [0.80285714 0.19714286]\n", + " [0.82222222 0.17777778]\n", + " [0.82666667 0.17333333]\n", + " [0.82698413 0.17301587]\n", + " [0.84131868 0.15868132]\n", + " [0.85333333 0.14666667]\n", + " [0.85555556 0.14444444]\n", + " [0.85555556 0.14444444]\n", + " [0.85714286 0.14285714]\n", + " [0.86 0.14 ]\n", + " [0.86666667 0.13333333]\n", + " [0.86989011 0.13010989]\n", + " [0.87179487 0.12820513]\n", + " [0.88888889 0.11111111]\n", + " [0.89333333 0.10666667]\n", + " [0.89846154 0.10153846]\n", + " [0.9047619 0.0952381 ]\n", + " [0.93142857 0.06857143]\n", + " [0.93333333 0.06666667]\n", + " [0.93846154 0.06153846]\n", + " [0.94285714 0.05714286]\n", + " [0.96 0.04 ]\n", + " [0.96666667 0.03333333]\n", + " [0.97142857 0.02857143]\n", + " [1. 0. ]]\n", + "equiv_counts is [173 8 3 3 8 1 5 3 1 5 1 1 3 1 5 2 1 1\n", + " 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 4\n", + " 1 2 1 1 2 1 1 1 3 5 1 1 1 7 1 86]\n", + "\n", + " 173 members for group 0\n", + "[ 1 2 3 5 6 7 10 11 15 18 20 22 26 28 29 30 32 36\n", + " 37 39 48 50 53 60 62 63 64 65 66 68 69 70 71 73 75 78\n", + " 81 82 84 88 92 93 96 102 103 109 110 115 117 119 122 123 124 126\n", + " 130 131 134 135 141 143 148 149 150 151 156 158 159 161 162 163 165 166\n", + " 174 180 181 182 186 189 191 192 195 197 198 199 200 202 203 206 210 212\n", + " 213 215 216 218 220 221 223 226 231 232 233 235 237 240 241 245 247 248\n", + " 253 254 256 259 260 261 263 264 267 269 270 271 272 281 284 285 290 291\n", + " 295 304 305 306 307 309 311 314 315 317 318 321 324 327 328 331 333 334\n", + " 335 337 338 339 341 342 345 349 350 355 360 361 365 369 370 373 375 377\n", + " 378 382 385 386 387 388 389 391 392 394 395]\n", + " 8 members for group 1\n", + "[ 16 43 104 121 152 177 229 384]\n", + " 3 members for group 2\n", + "[ 13 336 380]\n", + " 3 members for group 3\n", + "[ 98 287 301]\n", + " 8 members for group 4\n", + "[ 51 170 214 222 262 273 276 396]\n", + " 1 members for group 5\n", + "[279]\n", + " 5 members for group 6\n", + "[ 27 35 42 106 371]\n", + " 3 members for group 7\n", + "[ 55 129 138]\n", + " 1 members for group 8\n", + "[80]\n", + " 5 members for group 9\n", + "[ 79 116 164 252 286]\n", + " 1 members for group 10\n", + "[160]\n", + " 1 members for group 11\n", + "[77]\n", + " 3 members for group 12\n", + "[ 49 120 289]\n", + " 1 members for group 13\n", + "[176]\n", + " 5 members for group 14\n", + "[ 87 173 208 322 359]\n", + " 2 members for group 15\n", + "[91 99]\n", + " 1 members for group 16\n", + "[146]\n", + " 1 members for group 17\n", + "[376]\n", + " 2 members for group 18\n", + "[154 280]\n", + " 1 members for group 19\n", + "[196]\n", + " 1 members for group 20\n", + "[108]\n", + " 1 members for group 21\n", + "[46]\n", + " 1 members for group 22\n", + "[89]\n", + " 1 members for group 23\n", + "[234]\n", + " 1 members for group 24\n", + "[275]\n", + " 1 members for group 25\n", + "[72]\n", + " 1 members for group 26\n", + "[308]\n", + " 1 members for group 27\n", + "[136]\n", + " 1 members for group 28\n", + "[243]\n", + " 1 members for group 29\n", + "[368]\n", + " 1 members for group 30\n", + "[381]\n", + " 1 members for group 31\n", + "[67]\n", + " 1 members for group 32\n", + "[172]\n", + " 1 members for group 33\n", + "[111]\n", + " 1 members for group 34\n", + "[294]\n", + " 1 members for group 35\n", + "[348]\n", + " 1 members for group 36\n", + "[219]\n", + " 1 members for group 37\n", + "[178]\n", + " 1 members for group 38\n", + "[204]\n", + " 1 members for group 39\n", + "[312]\n", + " 1 members for group 40\n", + "[320]\n", + " 1 members for group 41\n", + "[169]\n", + " 1 members for group 42\n", + "[56]\n", + " 1 members for group 43\n", + "[364]\n", + " 1 members for group 44\n", + "[155]\n", + " 1 members for group 45\n", + "[113]\n", + " 1 members for group 46\n", + "[288]\n", + " 1 members for group 47\n", + "[54]\n", + " 1 members for group 48\n", + "[282]\n", + " 1 members for group 49\n", + "[351]\n", + " 1 members for group 50\n", + "[372]\n", + " 1 members for group 51\n", + "[105]\n", + " 1 members for group 52\n", + "[86]\n", + " 1 members for group 53\n", + "[8]\n", + " 1 members for group 54\n", + "[224]\n", + " 1 members for group 55\n", + "[225]\n", + " 1 members for group 56\n", + "[40]\n", + " 1 members for group 57\n", + "[313]\n", + " 1 members for group 58\n", + "[145]\n", + " 1 members for group 59\n", + "[19]\n", + " 1 members for group 60\n", + "[83]\n", + " 1 members for group 61\n", + "[193]\n", + " 1 members for group 62\n", + "[362]\n", + " 1 members for group 63\n", + "[14]\n", + " 1 members for group 64\n", + "[292]\n", + " 1 members for group 65\n", + "[356]\n", + " 1 members for group 66\n", + "[397]\n", + " 1 members for group 67\n", + "[367]\n", + " 1 members for group 68\n", + "[296]\n", + " 2 members for group 69\n", + "[ 0 207]\n", + " 2 members for group 70\n", + "[ 0 207]\n", + " 4 members for group 71\n", + "[ 52 242 302 344]\n", + " 1 members for group 72\n", + "[293]\n", + " 2 members for group 73\n", + "[323 330]\n", + " 1 members for group 74\n", + "[326]\n", + " 1 members for group 75\n", + "[379]\n", + " 2 members for group 76\n", + "[ 34 363]\n", + " 1 members for group 77\n", + "[266]\n", + " 1 members for group 78\n", + "[157]\n", + " 1 members for group 79\n", + "[297]\n", + " 3 members for group 80\n", + "[ 47 94 217]\n", + " 5 members for group 81\n", + "[ 24 168 205 319 329]\n", + " 1 members for group 82\n", + "[171]\n", + " 1 members for group 83\n", + "[268]\n", + " 1 members for group 84\n", + "[12]\n", + " 7 members for group 85\n", + "[114 127 132 209 250 283 298]\n", + " 1 members for group 86\n", + "[257]\n", + " 86 members for group 87\n", + "[ 4 9 17 21 23 25 31 33 38 41 44 45 57 58 59 61 74 76\n", + " 85 90 95 97 100 101 107 112 118 125 128 133 137 139 140 142 144 147\n", + " 153 167 175 179 183 184 185 187 188 190 194 201 211 227 228 230 236 238\n", + " 239 244 246 249 251 255 258 265 274 277 278 299 300 303 310 316 325 332\n", + " 340 343 346 347 352 353 354 357 358 366 374 383 390 393]\n" + ] + } + ], + "source": [ + "test_rf()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d9fc566-9c2f-468f-b664-41ff71fa366f", + "metadata": {}, + "outputs": [], + "source": [ + "from acro import ACRO\n", + "\n", + "acro = ACRO()\n", + "\n", + "from scipy.io.arff import loadarff\n", + "\n", + "path = os.path.join(\"../data\", \"nursery.arff\")\n", + "data = loadarff(path)\n", + "df = pd.DataFrame(data[0])\n", + "df = df.select_dtypes([object])\n", + "df = df.stack().str.decode(\"utf-8\").unstack()\n", + "df.rename(columns={\"class\": \"recommend\"}, inplace=True)\n", + "df.head()\n", + "df[\"children\"].replace(to_replace={\"more\": \"4\"}, inplace=True)\n", + "df[\"children\"] = pd.to_numeric(df[\"children\"])\n", + "\n", + "df[\"children\"] = df.apply(\n", + " lambda row: (\n", + " row[\"children\"] if row[\"children\"] in (1, 2, 3) else np.random.randint(4, 10)\n", + " ),\n", + " axis=1,\n", + ")\n", + "\n", + "mytable = acro.crosstab(\n", + " [data.survivor, data.year], data.grant_type, values=data.inc_grants, aggfunc=\"mean\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5cbff4-08cb-4e4d-933b-5cfd7928615f", + "metadata": {}, + "outputs": [], + "source": [ + "def get_whitebox_class_disclosure(yprobs:np.ndarray, \n", + " true_labels:np.array,\n", + " threshold:int,\n", + " ignore_zeros:Bool)->tuple[int,int]:\n", + " \"\"\" \n", + " function that ingests the proba values created\n", + " when a classifier is applied to a set of records\n", + " and returns details of whitebox group membership \n", + "\n", + " Parameters\n", + " ----------\n", + " yprobs: int\n", + " numpy 2d array, one row per record, one column per output class\n", + " true_labels: numpy 1Darray\n", + " one element for each row in yprobs, giving the actual class label\n", + " threshold :int\n", + " minimum number of (non-zero) records of each class in each equivalence group\n", + " ignore_zeros:Bool\n", + " should the threshold checking ignore 'evidential zeros' i.e. unrepresented classes\n", + " \n", + " Returns\n", + " --------\n", + " tuple [int,int]: \n", + " model is whitebox class disclosive (1) or not (0)\n", + " according to probability*membership tuple[0]\n", + " or actual group member labels tuple[1]\n", + " \n", + " \"\"\"\n", + " n_classes = yprobs.shape[1]\n", + " n_rows=yprobs.shape[0]\n", + " assert len(true_labels)==n_rows, f\"shape mismatch:lengths of yprobs {n_rows} and true_classes{len(true_classes)}\"\n", + " \n", + " uniques = np.unique(yprobs,axis=0,return_counts=True)\n", + " #groups are equivalance classes in predicted class probability space \n", + " uniq_probs=uniques[0]\n", + " uniq_freqs=uniques[1]\n", + " class_freqs= np.zeros( uniq_probs.shape,dtype=float)\n", + " membership=[]\n", + "\n", + " #check disclosure according to proba values\n", + " disclosive_by_freqs=1\n", + " for group in range( len(uniq_probs)):\n", + " class_freqs[group]= uniq_probs[group,:]*uniq_freqs[group]\n", + " for label in range(n_classes):\n", + " if class_freqs[group][label]== 0 and not ignore_zeros:\n", + " disclosive_by_freqs = 1\n", + " elif 0< class_freqs[group][label]< threshold :\n", + " disclosive_by_freqs = 1\n", + " else:\n", + " pass\n", + " \n", + " #now according to the labels of records falling in to each group\n", + " disclosive_by_labels=0\n", + " for prob_vals in uniq_probs:\n", + " ingroup = np.all(yprobs==prb_vals,axis=1)\n", + " \n", + " \n", + "def test_whitebox_class_disclosure(): \n", + "uprobs=uniques[0]\n", + "ufreqs=uniques[1]\n", + "class_freqs= np.zeros( uprobs.shape,dtype=float)\n", + "for group in range( len(uprobs)):\n", + " class_freqs[group]= uprobs[group,:]*ufreqs[group]\n", + " print(f'group {group} class_membership {class_freqs[group]}')\n", + " errmsg=f'class sum {class_freqs[group].sum()} should equal group count {ufreqs[group]}'\n", + " np.testing.assert_almost_equal( class_freqs[group].sum(), ufreqs[group],0.001),errmsg\n", + "print(f'class freqs are:\\n{class_freqs}')\n", + " \n", + " \n", + "uniqvals= [ [0.1,0.2,0.7],\n", + " [0.6,0.4,0.0],\n", + " [0.2,0.4,0.4]]\n", + "\n", + "\n", + "yprobs = np.zeros((20,3),dtype=float)\n", + "for i in range (20):\n", + " randval = np.random.randint(0,3)\n", + " yprobs[i] = np.array(uniqvals[randval])\n", + "#print( f'yprobs is \\n{yprobs}')\n", + "sorted_probs = yprobs[np.lexsort(([yprobs[:, i] for i in range(yprobs.shape[1]-1, -1, -1)]))]\n", + "#print( f'sorted_probs is \\n{sorted_probs}')\n", + "uniques = np.unique(sorted_probs,axis=0,return_counts=True)\n", + "print(f'np.uniq gives {len(uniques[0])}') \n", + "\n", + "uprobs=uniques[0]\n", + "ufreqs=uniques[1]\n", + "class_freqs= np.zeros( uprobs.shape,dtype=float)\n", + "for group in range( len(uprobs)):\n", + " class_freqs[group]= uprobs[group,:]*ufreqs[group]\n", + " print(f'group {group} class_membership {class_freqs[group]}')\n", + " errmsg=f'class sum {class_freqs[group].sum()} should equal group count {ufreqs[group]}'\n", + " np.testing.assert_almost_equal( class_freqs[group].sum(), ufreqs[group],0.001),errmsg\n", + "print(f'class freqs are:\\n{class_freqs}')\n", + "\n", + "#class disclosure step 3:loop through all similarity groups\n", + "r_ends = []\n", + "group_first = 0\n", + "group_last= 0\n", + "possible_next=group_last+1\n", + "while possible_nexttuple[int,int]:\n", - " \"\"\" \n", - " function that ingests the proba values created\n", - " when a classifier is applied to a set of records\n", - " and returns details of whitebox group membership \n", - "\n", - " Parameters\n", - " ----------\n", - " yprobs: int\n", - " numpy 2d array, one row per record, one column per output class\n", - " true_labels: numpy 1Darray\n", - " one element for each row in yprobs, giving the actual class label\n", - " threshold :int\n", - " minimum number of (non-zero) records of each class in each equivalence group\n", - " ignore_zeros:Bool\n", - " should the threshold checking ignore 'evidential zeros' i.e. unrepresented classes\n", - " \n", - " Returns\n", - " --------\n", - " tuple [int,int]: \n", - " model is whitebox class disclosive (1) or not (0)\n", - " according to probability*membership tuple[0]\n", - " or actual group member labels tuple[1]\n", - " \n", - " \"\"\"\n", - " n_classes = yprobs.shape[1]\n", - " n_rows=yprobs.shape[0]\n", - " assert len(true_labels)==n_rows, f\"shape mismatch:lengths of yprobs {n_rows} and true_classes{len(true_classes)}\"\n", - " \n", - " uniques = np.unique(yprobs,axis=0,return_counts=True)\n", - " #groups are equivalance classes in predicted class probability space \n", - " uniq_probs=uniques[0]\n", - " uniq_freqs=uniques[1]\n", - " class_freqs= np.zeros( uniq_probs.shape,dtype=float)\n", - " membership=[]\n", - "\n", - " #check disclosure according to proba values\n", - " disclosive_by_freqs=1\n", - " for group in range( len(uniq_probs)):\n", - " class_freqs[group]= uniq_probs[group,:]*uniq_freqs[group]\n", - " for label in range(n_classes):\n", - " if class_freqs[group][label]== 0 and not ignore_zeros:\n", - " disclosive_by_freqs = 1\n", - " elif 0< class_freqs[group][label]< threshold :\n", - " disclosive_by_freqs = 1\n", - " else:\n", - " pass\n", - " \n", - " #now according to the labels of records falling in to each group\n", - " disclosive_by_labels=0\n", - " for prob_vals in uniq_probs:\n", - " ingroup = np.all(yprobs==prb_vals,axis=1)\n", - " \n", - " \n", - "def test_whitebox_class_disclosure(): \n", - "uprobs=uniques[0]\n", - "ufreqs=uniques[1]\n", - "class_freqs= np.zeros( uprobs.shape,dtype=float)\n", - "for group in range( len(uprobs)):\n", - " class_freqs[group]= uprobs[group,:]*ufreqs[group]\n", - " print(f'group {group} class_membership {class_freqs[group]}')\n", - " errmsg=f'class sum {class_freqs[group].sum()} should equal group count {ufreqs[group]}'\n", - " np.testing.assert_almost_equal( class_freqs[group].sum(), ufreqs[group],0.001),errmsg\n", - "print(f'class freqs are:\\n{class_freqs}')\n", - " \n", - " \n", - "uniqvals= [ [0.1,0.2,0.7],\n", - " [0.6,0.4,0.0],\n", - " [0.2,0.4,0.4]]\n", - "\n", - "\n", - "yprobs = np.zeros((20,3),dtype=float)\n", - "for i in range (20):\n", - " randval = np.random.randint(0,3)\n", - " yprobs[i] = np.array(uniqvals[randval])\n", - "#print( f'yprobs is \\n{yprobs}')\n", - "sorted_probs = yprobs[np.lexsort(([yprobs[:, i] for i in range(yprobs.shape[1]-1, -1, -1)]))]\n", - "#print( f'sorted_probs is \\n{sorted_probs}')\n", - "uniques = np.unique(sorted_probs,axis=0,return_counts=True)\n", - "print(f'np.uniq gives {len(uniques[0])}') \n", - "\n", - "uprobs=uniques[0]\n", - "ufreqs=uniques[1]\n", - "class_freqs= np.zeros( uprobs.shape,dtype=float)\n", - "for group in range( len(uprobs)):\n", - " class_freqs[group]= uprobs[group,:]*ufreqs[group]\n", - " print(f'group {group} class_membership {class_freqs[group]}')\n", - " errmsg=f'class sum {class_freqs[group].sum()} should equal group count {ufreqs[group]}'\n", - " np.testing.assert_almost_equal( class_freqs[group].sum(), ufreqs[group],0.001),errmsg\n", - "print(f'class freqs are:\\n{class_freqs}')\n", - "\n", - "#class disclosure step 3:loop through all similarity groups\n", - "r_ends = []\n", - "group_first = 0\n", - "group_last= 0\n", - "possible_next=group_last+1\n", - "while possible_next