Skip to content

Commit

Permalink
Reduce n_epochs in CI for msr_banzhaf_digits.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbenito committed Jan 12, 2025
1 parent 504a639 commit f4201a2
Showing 1 changed file with 26 additions and 78 deletions.
104 changes: 26 additions & 78 deletions notebooks/msr_banzhaf_digits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": [
"hide"
]
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
Expand All @@ -70,7 +66,6 @@
"\n",
"is_CI = os.environ.get(\"CI\")\n",
"random_state = 24\n",
"n_jobs = 16\n",
"random.seed(random_state)"
]
},
Expand Down Expand Up @@ -123,53 +118,31 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide"
]
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# In CI we only use a subset of the training set\n",
"# Reduce computation time for CI\n",
"training_data = list(training_data)\n",
"if is_CI:\n",
" training_data[0] = training_data[0][:10]\n",
" training_data[1] = training_data[1][:10]\n",
" max_checks = 1\n",
" n_jobs = 2\n",
" n_epochs = 1\n",
"else:\n",
" training_data[0] = training_data[0][:200]\n",
" training_data[1] = training_data[1][:200]\n",
" max_checks = 1000"
" max_checks = 1000\n",
" n_jobs = 16\n",
" n_epochs = 40"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": [
"hide-input",
"invertible-output"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWkAAAGJCAYAAABIP8LMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoOUlEQVR4nO3de7xcZX3v8c+XcFHITe5CihsEjx6lIFAEixLFCwdsA1bFCy1R+zqCWI3YVyuvXtjgBauUi7XV1iqJp6hYL0GliFJJCgjkQEn0gBRtSBBIuEnCTQIkv/PH80wyTGZm71l7Ls/s/X2/XvPae69Zz3qeWfOb76xZa+01igjMzKxM2wx6AGZm1ppD2sysYA5pM7OCOaTNzArmkDYzK5hD2sysYA5pM7OCOaTNzArmkDYzK5hDugskrZK0qofLXygpJI30qo/JTNJ2ks6W9AtJG/K6PGHQ4+oVSSP5MS4c9FgaSZqfxzZ/0GMZFh2HdF7BY93mdn+oZpV9BPhr4F7gPOBs4PaBjmiC8utsyaDHMSiSRoclayY61m0n0PfZbe5bNYHl2tbOBD4F3DPogQypNwGPAa+PiKcGPRizTlQO6YgY7eI4rI2IWAOsGfQ4hthewEMOaBtKEdHRDYjUbFzz7gusA34NvKDhvp2AnwMbgbl1019E2mq8CXgA2ACsBv4JmNOkj7l5TKPAYcAPgPXAw8C3gN/K8+0HfD0v8zfA1cBBTZa3MC9vP+AM0sfiJ4G7gQuAmU3arAJWtVgH78h9rcvL+Tnwl8AOHazz2phG6qaN5GkLgRcC3wQeAh4Ffgi8LM+3W153a3L//xd4TZM+9iLtErgOWAs8Rdo98FXgf7YYl4APAbflZd8DfA6Y1a11ArwK+F5e/xvy2G4AzupgvTXeVjVZhy8CLgXuBzaRa5K0S/DUvN4eAx7Pv58GbNPi9bEE2AP4MnBfbvMT4FV1tf8ZUl1vAG4F3jrOWpjf4jEFMNrkcY2Q6v7BvK5vAt7UZvkTrte8nP2BfyW9DmuP//i68c9vmP81pDq9DXiE9Br9f8BZwHOavN6aroMJ5IiAU/I4H8iP/VfAlcBJTeafQ6r1lXnZDwHfBX6n07GOuS47WfF1RTj+DuAtuc11wLZ10xfl6Wc1zP/RXCDfAT5L2od4BemFswbYu2H+uXk5l+cn9ge5zZV5+n8BL85Fei3wt6RA20R6QU5v8cK+LBfYPwJ/AyzP029qUTRbBRLpRRr5yf5S7vu6PO3q+vUxzrBpFtJL8mO7Ji//W/mxPQgcAPw3cAtwIfAVUvg+CezT0MfbgSfyevz7/Ji/ned/jOZvaP+Qx3BP3XN1B7AsT5vQOgGOJb2JP5zr5ZPAF4ClwH3jWG8nkN681+XbaL4taFiH1+Q+biS9EX8BOCTPc0me5668Di9gywvvkhavj+XAL5us9yeAg0hvMnfk9fxPpDfWTcAR43hMB+fHEHkco3W3uQ2P62pSjd+Qx70oP/cbaf5G3a16PYBUfwH8W37evgE8TQqyZiH9g/x4vkp6A/s74D/r+p5WN+8CUt3X3og2r4MJ5Mgn8/JW5uflk8DFpDeKbzbMe0h+fJvyMs/L41hHCuzjOhlrz0K6oTjqbx9t82I+N/99Sv77xzRsjQB703yL6g25uD7fIqQDeFfDfV/K038N/EXDfX+V7/tQi0B8kLqtf9IW1bfyfX/V0GYVDYHEli2GbwPPbbhvtFnfbdZ5bUzNQjraPLZfkwJnm7r7/jDfd0FDm92BGU36PogU0lc0TH8VW94EZ9dN3x74D+q2WKuuk7r1fVCTce3aQc1u9fw0WYefbHL/O/J9/0ndmzlpS/imfN87W7w+Wq33X5M+GTyn7r7auvxOh6/DJS3uq39cZzXc98Y8/d96WK8/bDY/MK9uXPMb7tsPUJNlfSzPf1KLMc1tMYZOc+Qh0qe1HdvVGmkX8S9Jb3ZHN8y3F2njZE1932ONdcz12XGD1h+1ard1Tdo8h7R1sQn4AOlFfz/w/A77/imwsmHa3NzvNU3mf3W+707q3onzfS/I913cMH0hTYK4rpA2Anc2TF/F1oF0C2nLYXaT5UwjvQksG+fjro1ppG7aSJvHtk++73Eagjf3/TRwdQfr/bu5KLerm/bPuY8/ajL/79I8pDtaJ2wJ6RdVKe52z0/DOlxL8xf0j/L9b2hy3zH5vh83eX20W+8B7NdkeXc21tUYj2k8Ib2qsTby/auBB3tRr6TdAEHaIm3W9xKahHSb5e2c5/9yw/RRKgYfzXPkofwctN2tw5Y3ms+0uP9D+f7jujHWiJjQgUN1MO+Tkk4ibX38XR7wWyIdEHsWSQLeRXpnPwh4HqlIalod/LmpybR788/lEbGx4b7amRJzWixvaeOEiFgp6VfAiKTZEbGuWUNJO+axPwgsSA9pKxuAl7TouxPNHlvtcd8REY/W3xERGyXdR5PHLel40v7Xw4Bd2frA8q5sOYD58vzz2iZjugF4pmHZVdbJJcCbgRslXUr62HtdRNzdrPEErIiIDU2mH0LasFjS5L6lpDfslze5r9163ykiVjZpcw/wio5GPbZmtQFpd8aRtT+6XK+b66JF30uAoxsnStqJFHAnkvYnzyDtJ67Zexx91y+v0xy5BPgT4DZJ3yA9v9dHxPqG+Wrr7QWSRpt0fUD++RLSrp4Jm8gpeJ26g/QO9krSwYEftpjvfNJ+nDWk/cr3kPY1Q1rhL2jRrnFlwpag2Oq+iHgmF+N2LZZ3X4vpa/MYZpH2QTXzPFKB7UY68NFL7R5bs3UCab0863FL+hBp/+nDpC3Iu0j7UIO0b/cgYIe6JrPyz63WUw6khxomd7xOIuLbkt5EOs/5PcD78lhvBs6MiB+NZznjsLbF9FnAr6PJWSF5HT9I2k3UqN16b3dft1+P69r0Vf8/Et2s15Z1kW21riVtR9r1eThpH/ClpIN3T+dZzuLZtTcenebIh0lb/+8m7c/+KPCMpH8DPhIRv8zz7ZJ/vnWM/qd3ON6W+hnSHyUF9IPAS0nn/n6ifgZJuwMfJD1Rr2zcGpH0jv4MFUhH5/+ryfQ9889WL7b6+26JiEO6OqoekLQt6SPZWtIBszUN9x/ZpNkj+ecepOKun38aqZjrz+uutE4i4nLg8ryl9QrSOc+nAd+X9PKIuG28y2rXTYvp64GdJW0XEU/X35HX2a5sWQ/DrJv1WlvWHi3u37PJtHmkgF4YEe+uv0PS8+nwjaNKjuSt/guBC3P7o0gH098KvFTSS/OnrdrjmxcR3+1kXFX15d/CJb0SOIcUei/LP8+WdFTDrPvlMf2wyYqdk+/vl2YfyfYDfou0f3Ndq4YR8RjptKqXStq5ZyPsnl2B2cBPmgT0dNLH/ka35J+NzyHAETRsAEx0nUTE4xHx44g4g3TkfXvgf3W6nA7dQqrHVze579Wkj8//2eMxtLKJZ398r6zL9bq5LvKbdaO5Tabtn39+u8l9W70Os9qulGZ9TChHIuL+iPh2RLyNtIX/QlJuQdqVB+lg73i1G+uYeh7Skp4HfI000LdHxH3ASaSPXF9tKIpV+eeznuAcFF+kv1v+H5K0+SORpG1IpwZtQzo1Zyznk4Lky5JmN94p6XmSStnKvp+0a+PQvK6BzR9DLyKFeKOv5J9/IWlWXZvtSSHaTEfrRNKr8xZro9pW2hMtH1F3fDn/PDfvt62Na0fSObiQziAahIdIGwzd0pV6zccLfkT6H4kPNCxjHs1Dd1X+Obdh/v1Ip4I2U9udtk+b5Y0rRyTtIOl3GxeS67+WT7Vau4x0Wuvpko5rNjBJR9bXyxhjHVPl0Gux07xmcUQsz79/mTS4D9amRcQKSR8hnQy+EPj9PH2tpK+TPmYsl/RD0j6u15POLlhOOk+0H67LY7iU9BHnjaT9sjcDnx6rcUR8WdKhwPuB/5Z0JWk/786kAn41KexP7c3wxy8iNkn6LGmX1M8kXUZ6wb6GNN6r8+/1bZZK+ifgfwO3SvoWaR/i75HW172krb36Np2uk88Ce0u6jvTCewo4FHgt6QyFr3dxNWwlIr6ag+Vt+TEuZss++n2BSyPikl6OoY1/B94u6Xukrfmngf+IiP+osrAu1+vpwPWkXQdvAFaQtpZPJJ1++HsN83+PdFrbGZIOJG2N70PatXU5zcPtalJ9nSvpZaRjKUTExyvkyHOBayX9kvT6Xk06I+31pAOA342In+flPy3pzaT93JdL+kle3hOkN83fIW2pP58twd5yrGOvSnpyCt7m02tIR0sDuKzFsr6d7/9w3bQdSfuqa+ci/op0cvku5NN3GpYxNy9jtMnyR/J9C9s8liUN0xbm6fuRDljV/uPwHtI+q07/4/BNwPdJW6tPkfb7LgM+Drx4nOu8NqaRiTy2duMlvWGfQTqo+5s8zv9DOsCyVf+5zTakAy63k47+35ufq1mkf9BYPpF1QgrHrwG/IJ22+QhpP+MngN06qNmmz89Y67DuMb6fdPbQE/l2MymIWv7HYSfjyPdtVdtjPKbdSf/4cR/pU+rm18A4aqNlX92o17yc/Un/NLaOdEri9bT/j8PfIp1hUTvAdyvwZ7kum65T4GRSQP4mzxN19407R0gH0f+M9I8pd+X5HyDt2jgV2L7F+v9Urscncn3+Ij/mk2n4x592Yx3rprwAy5Qu73gKsG9ErBrsaIaTpANIZ/N8PSL6ebDXbNLx9aStMkl75n319dN2JH3igPQvuWY2Af08EGeTzwLgHUrXNV5DOr3qGNI/ylxBusCOmU2AQ9om4kekg6lvIB1geoa0m+OzwIXhfWlmE+Z90mZmBfM+aTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgm076AGMRZKAvYBHBz0WG5cZwL0REYMeyLBz7Q+dntR+8SFNKtK7Bz0I68gc4J5BD2IScO0Pn67X/jCEdF+3Iq699tqO2xx44IE9GElz69evr9TuqKOO6rjNXXfdVakvvOXXLX1dj1Xq+P3vf3+lvt75zndWalfF8ccf33GbKjmQdf05G4aQ7qvp06d33GbmzJk9GElzVT9JbbONDz9Ye9OmTeu4zXOf+9xKffXzNbPttsMdc3155Uo6XdIqSU9KulHS4f3o12zQXPs2UT0PaUknAecDZwOHACuAKyXt3uu+zQbJtW/d0I8t6TOAL0bExRFxG3Aq8ATwnj70bTZIrn2bsJ6GtKTtgUOBq2rTImJT/vvIFm12kDSzdiOd1mI2VFz71i293pLeFZgG3Ncw/T5gzxZtzgTW1918CpINI9e+dUWJh/zPBWbV3eYMdjhmfePat630+tyUB4GNwB4N0/cA1jZrEBEbgA21v9M/XZkNHde+dUVPt6Qj4ingZuCY2jRJ2+S/r+9l32aD5Nq3bunHWd7nA4sk3QQsAxYAOwEX96Fvs0Fy7duE9TykI+JSSbsB55AOmCwHjo2IxgMqRZg9e3bHbVasWFGpr1WrVnXcZt68eZX6sv4bttpfvnx5x23mz59fqa8FCxZ03KbK+AAOPvjgjtssWbKkUl+90Jf/l4yIzwGf60dfZiVx7dtElXh2h5mZZQ5pM7OCOaTNzArmkDYzK5hD2sysYA5pM7OCOaTNzArmkDYzK5hD2sysYA5pM7OCOaTNzArmkDYzK1hfLrA0TKpcaWtkZKRSX3Pnzu24zbvf/e5KfVW54p5Zr1S52mSVNlD96nml8Ja0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMF9gqcHFF1/ccZvFixdX6mvRokUdt1m4cGGlvszGUuWCX1WNjo72ra9Zs2b1ra9e8Ja0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBfBW8BieeeGLHbVavXl2prxNOOKFSO7N2Zs+eXaldlas59vMKcxdddFGlduvXr+/ySPrLW9JmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBFBGDHkNbkmYCHV8hpepFZpYvX95xm6oXSlqyZEnHbUZHRyv1deGFF1ZqV9GsiHiknx1ORlVrv6qRkZGO21Stqyp9HXzwwZX66rOu1763pM3MCuaQNjMrmEPazKxgPQ1pSaOSouF2ey/7NCuBa9+6pR/fzHIr8Lq6v5/pQ59mJXDt24T1I6SfiYi1fejHrDSufZuwfuyTPkDSvZJWSrpE0j7tZpa0g6SZtRswow9jNOsF175NWK9D+kZgPnAscBqwL3CNpHbFdybp3NDa7e4ej9GsF1z71hU9DemIuCIi/jUifhoRVwLHAbOBt7Vpdi4wq+42p5djNOsF1751Sz/2SW8WEesk3QHs32aeDcCG2t+S+jE0s55y7VtVfT1PWtJ04IXAmn72azZorn2rqtfnSZ8n6WhJI5JeCXwH2Ah8rZf9mg2aa9+6pde7O+aQinIX4AHgWuCIiHigx/2aDZpr37qipyEdEW/v5fLbqXpluipXwavSpqp169b1rS+rbpC1X1WV2po3b16lvqq+PqciX7vDzKxgDmkzs4I5pM3MCuaQNjMrmEPazKxgDmkzs4I5pM3MCuaQNjMrmEPazKxgDmkzs4I5pM3MCuaQNjMrWF8v+t9PixcvrtRuwYIFHbeJiEp9rV+/vuM2/byYk00tIyMjHbdZunRppb4uu+yySu2mIm9Jm5kVzCFtZlYwh7SZWcEc0mZmBXNIm5kVzCFtZlYwh7SZWcEc0mZmBXNIm5kVzCFtZlYwh7SZWcEc0mZmBZu0F1iqetGjxx57rOM2jzzySKW+qrTbuHFjpb7MxlKlth5//PEejMTqqWqY9YukvYG7Bz0O68iciLhn0IMYdq79odT12h+GkBawF/Bok7tnkIp4Tov7p5JS1sUM4N4ovbCGQJvaL+W5LkFJ66IntV/87o78gJu+M6UaBuDRiKi2z2GSKGhdTOnnoZta1X5Bz/XAFbYuetK/DxyamRXMIW1mVrBhD+kNwNn551TndTF1+LneYtKvi+IPHJqZTWXDviVtZjapOaTNzArmkDYzK5hD2sysYA5pM7OCDW1ISzpd0ipJT0q6UdLhgx5Tv0kalRQNt9sHPS7rLdf+1Kr9oQxpSScB55POjzwEWAFcKWn3gQ5sMG4Fnl93O2qww7Fecu0/y5So/aEMaeAM4IsRcXFE3AacCjwBvGewwxqIZyJibd3twUEPyHrKtb/FlKj9oQtpSdsDhwJX1aZFxKb895GDGtcAHSDpXkkrJV0iaZ9BD8h6w7W/lSlR+0MX0sCuwDTgvobp9wF79n84A3UjMB84FjgN2Be4RtKMQQ7Kesa1v8WUqf3iL1VqrUXEFXV//lTSjcBq4G3AlwYzKrPem0q1P4xb0g8CG4E9GqbvAazt/3DKERHrgDuA/Qc8FOsN134Lk7n2hy6kI+Ip4GbgmNo0Sdvkv68f1LhKIGk68EJgzaDHYt3n2m9tMtf+sO7uOB9YJOkmYBmwANgJuHiQg+o3SecB3yN9zNuLdFrWRuBrgxyX9ZRrn6lV+0MZ0hFxqaTdgHNIB0yWA8dGROMBlcluDqkodwEeAK4FjoiIBwY6KusZ1/5mU6b2fT1pM7OCDd0+aTOzqcQhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJdIGkkfzPEn3ZxmXPzMud2a5lm3eba770pG9KS5udCOGzQY+kVSa+TdLWkByWtk7RM0h8Oelw2WJO99vNXizV+tVbt9otBj69TQ/lv4TY2Sb8PLCZdeGcUCNJlHL8iadeIuGBwozPrqQXA9IZpLwA+Dvyw76OZIIf05PUB0hXBXhsRGwAk/SNwO+li6Q5pm5QiYnHjNEl/mX+9pL+jmbgpu7tjPCRtL+kcSTdLWi/pcUnXSHpNmzYflrRa0m8kLZX0sibzvFjSNyX9On/j8015y3es8eyY2+46juHPBB6uBTRARDxDuibxb8bR3qawIa/9Zt4J3BkRP6nYfmAc0u3NBP4YWAL8OWm3wW6kb2c+uMn8fwR8EPh74FzgZcCPJW2+SLuklwI3AC8BPgV8BHgcWCzpxDHGczjwc9JW8liWAC+V9DFJ+0t6oaS/Ag4DPj2O9ja1DXPtP4ukl+c+v9pp2yJExJS8kT7yB3BYm3mmAds3TJtN+haML9VNG8nLegLYu2764Xn6+XXTrgJ+CuxQN03AdcAdddPm5rZzm0wbHcfj2wm4FNiU2wTpBTFv0Ovet8HeJnvtN3ks5+W2Lxn0uq9y85Z0GxGxMdK3YSBpG0k7k/bj3wQc0qTJ4oi4p679MtIXZh6Xl7Ez8FrgG8AMSbvmj2+7AFeSvv147zbjWRIRiojRcQx/A+nrhL4JvAM4OY/7XyQdMY72NoUNee1vlr+55u3ALRHx807alsIHDscg6RTSx7IXA9vV3XVnk9mbnd5zB+msCkjfvybgY/nWzO7APS3u68TngCOAQyJiE4CkbwC3AhcBr+hCHzaJDXHt1zsa2JshPlDukG5D0snAQtKpbJ8B7id9Rc+ZpO9T61Ttk8t5pK2HZn5ZYbnPIml74L3Ap2sBDRART0u6AviApO1rW0pmjYa19pt4F2mX39B+rZZDur23ACuBN0feuQUg6ewW8x/QZNqLgFX595X559MRcVW3BtnELqTndlqT+7YjvWCa3WdWM6y1v5mkHYA/AJZExL396LMXvE+6vY35p2oTJL0COLLF/CfU71eTdDhpt8IVABFxP+lo+fskPb+xsdJ317XUwWlI9wPrgBPzVnWt/XTg94DbI8Kn4Vk7w1r79Y4jHewcunOj63lLGt4j6dgm0y8Cvg+8GfiOpMuBfYFTgdvY+j+aIH1cu1bS54EdSP/59BDPPuXtdNKXZv5M0hdJWxh7kIp/DnBQm7EeDlxN+mbk0VYzRcRGpW9T/jhwg6SvkLac35v7OLlNHzZ1TLrab/Au0gH0b41z/iI5pOG0FtMX5tuewPuAN5IK9GTgraRTghp9hbT/awHpIMgy4AMRsaY2Q0TcpnTNhLNIp0LtQtryvYX0DdBdERGfkHQn8KHc1w6k05/eEhFDXbTWNZOy9gEkzQSOBy6PiPXdXHa/+dvCzcwK5n3SZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwYr/tnBJAvYCHh30WGxcZgD3hr/heMJc+0OnJ7VffEiTivTuQQ/COjIHuGfQg5gEXPvDp+u1Pwwh3detiOOOO67jNl/4whcq9TVr1qyO26xfv75SXwceeGDf+sJbft3S1/VYpUaOP/74Sn399m//dsdt9tlnn0p9VXmdVVkXWdefs2EI6b7abrvtOm4zc+bMSn1VaVf1k1T65GzW2rRp0zpu85znPKdSXzvuuGPHbaZPn16prxkzZlRqV4q+HDiUdLqkVZKelHSjpMP70a/ZoLn2baJ6HtKSTgLOB84GDgFWAFdK2r3XfZsNkmvfuqEfW9JnAF+MiIsj4jbgVOAJ4D3NZpa0g6SZtRvpiKnZMHLt24T1NKQlbQ8cClxVmxYRm/LfR7Zodiawvu7mo9s2dFz71i293pLeFZgG3Ncw/T5gzxZtzgVm1d3m9Gx0Zr3j2reuKO7sjojYAGyo/e2zEmyqcO1bM73ekn4Q2Ajs0TB9D2Btj/s2GyTXvnVFT0M6Ip4CbgaOqU2TtE3++/pe9m02SK5965Z+7O44H1gk6SZgGbAA2Am4uA99mw2Sa98mrOchHRGXStoNOId0wGQ5cGxENB5Q6aqFCxdWanfKKad03Oayyy6r1Nfy5cs7bnPWWWdV6uvoo4/uuE3Vx2XJoGp/ZGSkUrsqr5mDDjqoUl8rVqyo1K6KdevW9a2vXujLgcOI+BzwuX70ZVYS175NlK8nbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFay4b2bplqpXAlu0aFHHbRYsWFCprwsvvLDjNuvXr6/U19KlSyu1s+FT9apvVa5o52+P6T1vSZuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgWbtBdYmjt3bt/6Gh0drdTulFNO6bjNZZddVqkvmzqqXlysisWLF1dqV+UiUFUvZFb1glOl8Ja0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBHNJmZgVzSJuZFcwhbWZWMIe0mVnBJu1V8PrpwgsvrNSuytXKqlw5D6qNcf78+ZX6sqnj4IMP7ltfVV9nw17H3pI2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwK5pA2MyuYQ9rMrGAOaTOzgjmkzcwKpogY9BjakjQTWD/ocbRT5UJJAKtWreq4zYIFCyr1dcEFF3TcRlKlvoBZEfFI1caW9Lv2R0dHO26zePHiSn1VuehR1dfZCSecUKldRV2vfW9Jm5kVzCFtZlYwh7SZWcF6GtKSRiVFw+32XvZpVgLXvnVLP76Z5VbgdXV/P9OHPs1K4Nq3CetHSD8TEWvHO7OkHYAd6ibN6P6QzPrCtW8T1o990gdIulfSSkmXSNpnjPnPJJ12VLvd3fMRmvWGa98mrNchfSMwHzgWOA3YF7hGUrsthHOBWXW3OT0eo1kvuPatK3q6uyMirqj786eSbgRWA28DvtSizQZgQ+3vCfxDhdnAuPatW/p6Cl5ErAPuAPbvZ79mg+bat6r6GtKSpgMvBNb0s1+zQXPtW1W9Pk/6PElHSxqR9ErgO8BG4Gu97Nds0Fz71i29PgVvDqkodwEeAK4FjoiIB3rcr9mgufatK3p94PDtvVx+L1S5EljVq2ydddZZHbepMj6ApUuXVmpn1Qxj7a9bt67jNkuWLOlbX8uXL6/U17DztTvMzArmkDYzK5hD2sysYA5pM7OCOaTNzArmkDYzK5hD2sysYA5pM7OCOaTNzArmkDYzK5hD2sysYA5pM7OCKSIGPYa2JM0kfd9bX8ybN6/jNosWLarU16xZszpus3r16kp9VbkI1AQuaDMrIh6p2tiSftf+yMhIx20WLlxYqa/Zs2d33KbqhcxWrVpVqV1FXa99b0mbmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYFc0ibmRXMIW1mVjCHtJlZwRzSZmYF23bQAyjN008/3XGbRx6pdj0VSR23efTRRyv1tXHjxkrtbOrYtGlTx20ef/zxSn1tu23n0VNlfJPBMFwFb2/g7kGPwzoyJyLuGfQghp1rfyh1vfaHIaQF7AU024ScQSriOS3un0pKWRczgHuj9MIaAm1qv5TnugQlrYue1H7xuzvyA276zlS3u+DRqX794oLWxZR+HrqpVe0X9FwPXGHroif9+8ChmVnBHNJmZgUb9pDeAJydf051XhdTh5/rLSb9uij+wKGZ2VQ27FvSZmaTmkPazKxgDmkzs4I5pM3MCja0IS3pdEmrJD0p6UZJhw96TP0maVRSNNxuH/S4rLdc+1Or9ocypCWdBJxPOvXmEGAFcKWk3Qc6sMG4FXh+3e2owQ7Hesm1/yxTovaHMqSBM4AvRsTFEXEbcCrwBPCewQ5rIJ6JiLV1twcHPSDrKdf+FlOi9ocupCVtDxwKXFWbFhGb8t9HDmpcA3SApHslrZR0iaR9Bj0g6w3X/lamRO0PXUgDuwLTgPsapt8H7Nn/4QzUjcB84FjgNGBf4BpJMwY5KOsZ1/4WU6b2i78KnrUWEVfU/flTSTcCq4G3AV8azKjMem8q1f4wbkk/CGwE9miYvgewtv/DKUdErAPuAPYf8FCsN1z7LUzm2h+6kI6Ip4CbgWNq0yRtk/++flDjKoGk6cALgTWDHot1n2u/tclc+8O6u+N8YJGkm4BlwAJgJ+DiQQ6q3ySdB3yP9DFvL9JpWRuBrw1yXNZTrn2mVu0PZUhHxKWSdgPOIR0wWQ4cGxGNB1QmuzmkotwFeAC4FjgiIh4Y6KisZ1z7m02Z2velSs3MCjZ0+6TNzKYSh7SZWcEc0mZmBXNIm5kVzCFtZlYwh7SZWcEc0mZmBXNIm5kVzCHdBZJG8tf3/GkXlzk3L3Nut5Zp1m2u/d6bsiEtaX4uhMMGPZZekHSipCvzRdE3SLpb0jclvWzQY7PBmgK13+z7D0PSk4MeWxVDee0OG5cDgYeBi0iXuNyT9BVLyyQdGRErBjk4sz44DXis7u+NgxrIRDikJ6mIOKdxmqR/Bu4mFe+pfR+UWX99czJ87+GU3d0xHpK2l3SOpJslrZf0uKRrJL2mTZsPS1ot6TeSljbbvSDpxXnXw68lPSnpJkm/P47x7Jjb7lrxId1P+tLS2RXb2xQxSWpfkmZKUgdtiuOQbm8m8MfAEuDPgVFgN+BKSQc3mf+PgA8Cfw+cC7wM+LGkzd+kIemlwA3AS4BPAR8BHgcWSzpxjPEcDvwc+MB4H4Ck2ZJ2k3Qg8M/5Mf37eNvblDX0tQ+sBNYDj0r6l/qxDBPv7mjvYWAkfyMGAJK+CNwO/Anw3ob59wcOiIh78rw/IH1h5p8DZ+R5LgLuAn4nIjbk+f6BdD3cvwG+0+XHcAPwP/LvjwEfZ5J9B5z1xDDX/sPA50jfVrMBeBVwOnC4pMMi4pEu9dMXDuk2ImIj+WBD/pqi2aRPHzcBhzRpsrhWpLn9svwFmccBZ0jaGXgt8NfAjIZvNr4SOFvS3vXLaBjPEqDTj27vJm0V7Zd/fy7pG6c3dbgcm0KGufYj4qKGSd+StAy4BHg/aSt+aDikxyDpFNLHshcD29XddWeT2X/RZNodpG8whrS1IeBj+dbM7kDTQq0iIjZ/952kr5M+MgJ07bxWm5yGvfbrRcRXJf0t8Doc0pOHpJOBhcBi4DOkA28bgTNJX3rZqdoxgPNIWw/N/LLCcsclIh6W9GPgXTikrY3JVvvZr4Cde9xH1zmk23sL6eDDm6Pue8Yknd1i/gOaTHsRsCr/vjL/fDoirurWIDv0XGDWgPq24TGpaj+f4TEC3NLvvifKZ3e0Vzv5ffO+MEmvAI5sMf8Jkvaum/dw4BXAFQARcT/paPn7JD2/sbHSF4y21MlpSJJ2bzJtBDiGtF/RrJ1hrv1myzqNdHbKD8ZqXxpvScN7JB3bZPpFwPeBNwPfkXQ5sC/pn0BuA6Y3afNL4FpJnwd2ABYADwGfrpvndNLR7J/lo+UrgT1IxT8HOKjNWA8HriZ9ff3oGI/rZ5L+nfRt0g+TtnTeS9q3+NEx2trUMFlrf7WkS4GfAU8CRwFvJ70W/nGMtsVxSKd32GYW5tuewPuAN5IK9GTgrcDcJm2+QjprYgHpIMgy4AMRsaY2Q0TcpnTNhLOA+aSvpL+f9DFsq/8SnIDPA8cDxwIzch8/BD4ZET/rYj82vCZr7V8CvBL4A+A5wGrSm8UnIuKJLvbTF6rb3WRmZoXxPmkzs4I5pM3MCuaQNjMrmEPazKxgDmkzs4I5pM3MCuaQNjMrmEPazKxgDmkzs4I5pM3MCuaQNjMrmEPazKxg/x89BghkxmQJjQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 400x400 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize some of the data\n",
"fig, axes = plt.subplots(2, 2, figsize=(4, 4))\n",
Expand Down Expand Up @@ -219,28 +192,15 @@
"from support.banzhaf import TorchCNNModel\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = TorchCNNModel(lr=0.001, epochs=40, batch_size=32, device=device)\n",
"model = TorchCNNModel(lr=0.001, epochs=n_epochs, batch_size=32, device=device)\n",
"model.fit(x=training_data[0], y=training_data[1])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.705\n",
"Test Accuracy: 0.630\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Train Accuracy: {model.score(x=training_data[0], y=training_data[1]):.3f}\")\n",
"print(f\"Test Accuracy: {model.score(x=test_data[0], y=test_data[1]):.3f}\")"
Expand Down Expand Up @@ -550,21 +510,9 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████▉| 99.9/100 [00:59<00:00, 1.69%/s] \n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"anomalous_dataset = Dataset(\n",
" x_train=x_train_anomalous,\n",
Expand All @@ -574,7 +522,7 @@
")\n",
"\n",
"anomalous_utility = Utility(\n",
" model=TorchCNNModel(),\n",
" model=TorchCNNModel(lr=0.001, epochs=n_epochs, batch_size=32, device=device),\n",
" data=anomalous_dataset,\n",
" scorer=Scorer(\"accuracy\", default=0.0, range=(0, 1)),\n",
" cache_backend=MemcachedCacheBackend(MemcachedClientConfig()),\n",
Expand Down Expand Up @@ -722,12 +670,12 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"utility = Utility(\n",
" model=TorchCNNModel(),\n",
" model=TorchCNNModel(lr=0.001, epochs=n_epochs, batch_size=32, device=device),\n",
" data=dataset,\n",
" scorer=Scorer(\"accuracy\", default=0.0, range=(0, 1)),\n",
" cache_backend=MemcachedCacheBackend(MemcachedClientConfig()),\n",
Expand Down Expand Up @@ -969,7 +917,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -983,15 +931,15 @@
" )\n",
"else:\n",
" utility = Utility(\n",
" model=TorchCNNModel(),\n",
" model=TorchCNNModel(lr=0.001, epochs=n_epochs, batch_size=32, device=device),\n",
" data=dataset,\n",
" scorer=Scorer(\"accuracy\", default=0.0, range=(0, 1)),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit f4201a2

Please sign in to comment.