Commit
chore: cleanup notebook
andrei-stoian-zama committed May 21, 2024
1 parent 11f6c37 commit 82ec666
Showing 2 changed files with 556 additions and 31 deletions.
@@ -20,10 +20,10 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import time\n",
"import warnings\n",
"import os\n",
"\n",
"import numpy as np\n",
"import torch\n",
@@ -68,9 +68,9 @@
"\n",
"# The timing and the accuracy recorded in the article\n",
"if os.cpu_count() > 48:\n",
" PAPER_NOTES = { 20: [21.17, 0.97], 50: [43.91, 0.94]}\n",
" PAPER_NOTES = {20: [21.17, 0.971], 50: [43.91, 0.947]}\n",
"else:\n",
" PAPER_NOTES = { 20: [115.52, 0.97], 50: [233.55, 0.94]}"
" PAPER_NOTES = {20: [115.52, 0.971], 50: [233.55, 0.947]}"
]
},
{
@@ -91,7 +91,7 @@
" # in_channel=1, out_channels=1, kernel_size=3, stride=1, padding_mode='replicate'\n",
" (\"C\", 1, 1, 3, 1, \"replicate\"),\n",
" (\"R\",),\n",
" (\"B\", 1, 30), # 2d batch-norm for 1 channel\n",
" (\"B\", 1, 30), # 2d batch-norm for 1 channel\n",
"]\n",
"\n",
"\n",
@@ -102,19 +102,18 @@
" [\n",
" (\"L\", INPUT_IMG_SIZE * INPUT_IMG_SIZE, 92),\n",
" (\"R\",),\n",
" (\"B\", 92), # 1d batch norm\n",
" (\"B\", 92), # 1d batch norm\n",
" ] # noqa: W503\n",
" + [ # noqa: W503\n",
" (\"L\", 92, 92),\n",
" (\"R\",),\n",
" (\"B\", 92), # 1d batch norm\n",
" (\"B\", 92), # 1d batch norm\n",
" ]\n",
" * (nb_layers - 3) # noqa: W503\n",
" + [\n",
" (\"L\", 92, output_size)\n",
" ] # noqa: W503\n",
" + [(\"L\", 92, output_size)] # noqa: W503\n",
" )\n",
" \n",
"\n",
"\n",
"class Fp32MNIST(torch.nn.Module):\n",
" \"\"\"MNIST Torch model.\"\"\"\n",
"\n",
@@ -143,7 +142,7 @@
" return torch.nn.Linear(in_features=t[1], out_features=t[2])\n",
" if t[0] == \"R\":\n",
" return torch.nn.ReLU()\n",
" if t[0] == 'B':\n",
" if t[0] == \"B\":\n",
" if len(t) == 2:\n",
" return torch.nn.BatchNorm1d(t[1])\n",
" elif len(t) == 3:\n",
@@ -167,9 +166,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load MNIST data-set\n",
"\n",
"At the time of writing this notebook, `padding=1` is not yet supported by Concrete ML ; as a workaround, padding is added during the data loading transformation process."
"## Load and pre-process the MNIST data-set\n"
]
},
{
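The removed markdown text above notes that `padding=1` was not yet supported and that padding was instead applied while loading the data. A minimal sketch of that kind of workaround, assuming torchvision handles the MNIST pipeline; the Pad value, dataset root, and batch size are illustrative and not taken from the notebook:

# Illustrative only: pad each digit during data loading instead of inside the
# convolution, as the removed markdown above describes (assumes torchvision).
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose(
    [
        transforms.Pad(1),      # grow 28x28 images to 30x30 before the model sees them
        transforms.ToTensor(),  # convert PIL images to float tensors in [0, 1]
    ]
)

train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)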
@@ -418,7 +415,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In the compilation step, the compiler requires an exhaustive set of data, here noted `data_calibration` to evaluate the maximum integer bit-width within the graph."
"In the compilation step, the compiler requires an exhaustive set of data, named `data_calibration` below, to evaluate the maximum integer bit-width within the graph."
]
},
{
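Both versions of the sentence above describe a calibration set that the compiler uses to bound the integer bit-widths in the graph. A minimal sketch of how such a set is typically passed to Concrete ML's `compile_torch_model`; the stand-in model, the random calibration data, and the `n_bits` value are assumptions for illustration, not the notebook's actual settings:

# Illustrative only: compile a small stand-in model with a calibration set.
import numpy as np
import torch
from concrete.ml.torch.compile import compile_torch_model

# Stand-in for the trained Fp32MNIST model built in the earlier cells.
torch_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(10 * 10, 10))

# In the notebook, data_calibration holds real pre-processed MNIST samples;
# a deliberately small random placeholder is used here so the sketch compiles quickly.
data_calibration = np.random.rand(100, 1, 10, 10).astype(np.float32)

q_module = compile_torch_model(
    torch_model,       # model to quantize and compile
    data_calibration,  # calibration inputs used to evaluate integer bit-widths
    n_bits=3,          # illustrative bit-width, kept low to respect FHE limits
)

# The compiled module is then called as in the cells below, e.g.
# q_module.forward(data_calibration[0, None], fhe="simulate") or fhe="execute".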
@@ -487,15 +484,14 @@
" q_module.forward(data[0, None], fhe=\"execute\")\n",
" fhe_timing.append((time.time() - start_time))\n",
"\n",
" results_cml[nb_layers] = [ acc_test, np.mean(y_predictions), np.min(fhe_timing)]\n",
" \n",
" results_cml[nb_layers] = [acc_test, np.mean(y_predictions), np.min(fhe_timing)]\n",
"\n",
" print(\n",
" f\"Running NN-{nb_layers} on a {MACHINE} machine:\"\n",
" f\"Accuracy in fp32 : {results_cml[nb_layers][0]:.3%} for the test set\\n\"\n",
" f\"Accuracy in FHE simulation mode : {results_cml[nb_layers][1]:.3%} for the test set\\n\"\n",
" f\"Timing in FHE: {results_cml[nb_layers][2]:.3f}s per sample.\"\n",
" )\n",
"\n"
" )"
]
},
{
@@ -566,18 +562,46 @@
"source": [
"import pandas as pd\n",
"\n",
"pd.DataFrame([\n",
" [20, PAPER_NOTES[20][1], PAPER_NOTES[20][0], results_cml[20][0], results_cml[20][1], results_cml[20][2], PAPER_NOTES[20][0]/results_cml[20][2]], \n",
" [50, PAPER_NOTES[50][1], PAPER_NOTES[50][0], results_cml[50][0], results_cml[50][1], results_cml[50][2], PAPER_NOTES[50][0]/results_cml[50][2]]\n",
" ], columns=[\"Num Layers\", \"Accuracy [1]\", \"FHE Latency [1]\", \"Our Accuracy fp32\", \"Our Accuracy FHE\", \"Our FHE Latency\", \"Speedup\"]\n",
").style.format({\n",
" 'Accuracy [1]': '{:,.2%}'.format,\n",
" 'FHE Latency [1]':'{:,.2f}s'.format,\n",
" 'Our Accuracy fp32': '{:,.2%}'.format,\n",
" 'Our Accuracy FHE': '{:,.2%}'.format,\n",
" 'Our FHE Latency': '{:,.2f}s'.format,\n",
" 'Speedup': '{:,.1f}x'.format\n",
"})"
"pd.DataFrame(\n",
" [\n",
" [\n",
" 20,\n",
" PAPER_NOTES[20][1],\n",
" PAPER_NOTES[20][0],\n",
" results_cml[20][0],\n",
" results_cml[20][1],\n",
" results_cml[20][2],\n",
" PAPER_NOTES[20][0] / results_cml[20][2],\n",
" ],\n",
" [\n",
" 50,\n",
" PAPER_NOTES[50][1],\n",
" PAPER_NOTES[50][0],\n",
" results_cml[50][0],\n",
" results_cml[50][1],\n",
" results_cml[50][2],\n",
" PAPER_NOTES[50][0] / results_cml[50][2],\n",
" ],\n",
" ],\n",
" columns=[\n",
" \"Num Layers\",\n",
" \"Accuracy [1]\",\n",
" \"FHE Latency [1]\",\n",
" \"Our Accuracy fp32\",\n",
" \"Our Accuracy FHE\",\n",
" \"Our FHE Latency\",\n",
" \"Speedup\",\n",
" ],\n",
").style.format(\n",
" {\n",
" \"Accuracy [1]\": \"{:,.1%}\".format,\n",
" \"FHE Latency [1]\": \"{:,.2f}s\".format,\n",
" \"Our Accuracy fp32\": \"{:,.1%}\".format,\n",
" \"Our Accuracy FHE\": \"{:,.1%}\".format,\n",
" \"Our FHE Latency\": \"{:,.2f}s\".format,\n",
" \"Speedup\": \"{:,.1f}x\".format,\n",
" }\n",
")"
]
}
],
