diff --git a/docs/advanced_examples/ConvolutionalNeuralNetwork.ipynb b/docs/advanced_examples/ConvolutionalNeuralNetwork.ipynb index ee74318c8..b9774e3bb 100644 --- a/docs/advanced_examples/ConvolutionalNeuralNetwork.ipynb +++ b/docs/advanced_examples/ConvolutionalNeuralNetwork.ipynb @@ -42,11 +42,10 @@ "from sklearn.datasets import load_digits\n", "from sklearn.model_selection import train_test_split\n", "from torch import nn\n", - "from torch.nn.utils import prune\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from tqdm import tqdm\n", "\n", - "from concrete.ml.torch.compile import compile_brevitas_qat_model\n", + "from concrete.ml.torch.compile import compile_torch_model\n", "\n", "# And some helpers for visualization.\n", "\n", @@ -71,7 +70,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -126,80 +125,27 @@ "metadata": {}, "outputs": [], "source": [ - "import brevitas.nn as qnn\n", - "\n", - "\n", "class TinyCNN(nn.Module):\n", - " \"\"\"A very small CNN to classify the sklearn digits data-set.\n", - "\n", - " This class also allows pruning to a maximum of 10 active neurons, which\n", - " should help keep the accumulator bit width low.\n", - " \"\"\"\n", + " \"\"\"A very small CNN to classify the sklearn digits data-set.\"\"\"\n", "\n", - " def __init__(self, n_classes, n_bits) -> None:\n", + " def __init__(self, n_classes) -> None:\n", " \"\"\"Construct the CNN with a configurable number of classes.\"\"\"\n", " super().__init__()\n", "\n", - " a_bits = n_bits\n", - " w_bits = n_bits\n", - "\n", " # This network has a total complexity of 1216 MAC\n", - " self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)\n", - " self.conv1 = qnn.QuantConv2d(1, 8, 3, stride=1, padding=0, weight_bit_width=w_bits)\n", - " self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)\n", - " self.conv2 = qnn.QuantConv2d(8, 16, 3, stride=2, padding=0, weight_bit_width=w_bits)\n", - " self.q3 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)\n", - " self.conv3 = qnn.QuantConv2d(16, 32, 2, stride=1, padding=0, weight_bit_width=w_bits)\n", - " self.q4 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)\n", - " self.fc1 = qnn.QuantLinear(\n", - " 32,\n", - " n_classes,\n", - " bias=True,\n", - " weight_bit_width=w_bits,\n", - " )\n", - "\n", - " # Enable pruning, prepared for training\n", - " self.toggle_pruning(True)\n", - "\n", - " def toggle_pruning(self, enable):\n", - " \"\"\"Enables or removes pruning.\"\"\"\n", - "\n", - " # Maximum number of active neurons (i.e., corresponding weight != 0)\n", - " n_active = 12\n", - "\n", - " # Go through all the convolution layers\n", - " for layer in (self.conv1, self.conv2, self.conv3):\n", - " s = layer.weight.shape\n", - "\n", - " # Compute fan-in (number of inputs 
to a neuron)\n", - " # and fan-out (number of neurons in the layer)\n", - " st = [s[0], np.prod(s[1:])]\n", - "\n", - " # The number of input neurons (fan-in) is the product of\n", - " # the kernel width x height x inChannels.\n", - " if st[1] > n_active:\n", - " if enable:\n", - " # This will create a forward hook to create a mask tensor that is multiplied\n", - " # with the weights during forward. The mask will contain 0s or 1s\n", - " prune.l1_unstructured(layer, \"weight\", (st[1] - n_active) * st[0])\n", - " else:\n", - " # When disabling pruning, the mask is multiplied with the weights\n", - " # and the result is stored in the weights member\n", - " prune.remove(layer, \"weight\")\n", + " self.conv1 = nn.Conv2d(1, 8, 3, stride=1, padding=0)\n", + " self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=0)\n", + " self.conv3 = nn.Conv2d(16, 32, 2, stride=1, padding=0)\n", + " self.fc1 = nn.Linear(32, n_classes)\n", "\n", " def forward(self, x):\n", " \"\"\"Run inference on the tiny CNN, apply the decision layer on the reshaped conv output.\"\"\"\n", - "\n", - " x = self.q1(x)\n", " x = self.conv1(x)\n", " x = torch.relu(x)\n", - " x = self.q2(x)\n", " x = self.conv2(x)\n", " x = torch.relu(x)\n", - " x = self.q3(x)\n", " x = self.conv3(x)\n", " x = torch.relu(x)\n", - " x = self.q4(x)\n", " x = x.flatten(1)\n", " x = self.fc1(x)\n", " return x" @@ -237,16 +183,19 @@ "name": "stderr", "output_type": "stream", "text": [ - "Training with 2 bit weights and activations: 100%|██████████| 150/150 [01:05<00:00, 2.28it/s]\n", - "Training with 3 bit weights and activations: 100%|██████████| 150/150 [01:00<00:00, 2.49it/s]\n", - "Training with 4 bit weights and activations: 100%|██████████| 150/150 [01:02<00:00, 2.39it/s]\n", - "Training with 5 bit weights and activations: 100%|██████████| 150/150 [01:02<00:00, 2.38it/s]\n", - "Training with 6 bit weights and activations: 100%|██████████| 150/150 [01:05<00:00, 2.29it/s]\n" + "Training: 0%| | 0/150 [00:00" ] @@ -287,29 
+236,17 @@ "test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))\n", "test_dataloader = DataLoader(test_dataset)\n", "\n", - "nets = []\n", - "bit_range = range(2, 7)\n", - "\n", "# Train the network with Adam, output the test set accuracy every epoch\n", - "losses = []\n", - "for n_bits in bit_range:\n", - " net = TinyCNN(10, n_bits)\n", - " losses_bits = []\n", - " optimizer = torch.optim.Adam(net.parameters())\n", - " for _ in tqdm(range(N_EPOCHS), desc=f\"Training with {n_bits} bit weights and activations\"):\n", - " losses_bits.append(train_one_epoch(net, optimizer, train_dataloader))\n", - " losses.append(losses_bits)\n", - "\n", - " # Finally, disable pruning (sets the pruned weights to 0)\n", - " net.toggle_pruning(False)\n", - " nets.append(net)\n", + "net = TinyCNN(10)\n", + "losses_bits = []\n", + "optimizer = torch.optim.Adam(net.parameters())\n", + "for _ in tqdm(range(N_EPOCHS), desc=\"Training\"):\n", + " losses_bits.append(train_one_epoch(net, optimizer, train_dataloader))\n", "\n", "fig = plt.figure(figsize=(8, 4))\n", - "for losses_bits in losses:\n", - " plt.plot(losses_bits)\n", + "plt.plot(losses_bits)\n", "plt.ylabel(\"Cross Entropy Loss\")\n", "plt.xlabel(\"Epoch\")\n", - "plt.legend(list(map(str, bit_range)))\n", "plt.title(\"Training set loss during training\")\n", "plt.grid(True)\n", "plt.show()" @@ -333,11 +270,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test accuracy for 2-bit weights and activations: 72.89%\n", - "Test accuracy for 3-bit weights and activations: 89.56%\n", - "Test accuracy for 4-bit weights and activations: 96.00%\n", - "Test accuracy for 5-bit weights and activations: 96.89%\n", - "Test accuracy for 6-bit weights and activations: 96.44%\n" + "Test accuracy for 6-bit weights and activations: 98.22%\n" ] } ], @@ -372,8 +305,7 @@ " )\n", "\n", "\n", - "for idx, net in enumerate(nets):\n", - " test_torch(net, bit_range[idx], test_dataloader)" + "test_torch(net, 6, test_dataloader)" ] }, 
{ @@ -441,14 +373,9 @@ "### Test the network using Simulation\n", "\n", "Note that this is not a test in FHE. The simulated FHE mode gives \n", - "insight into the number of accumulator bits that are needed and the \n", - "impact of FHE execution on the accuracy.\n", + "insight into the impact of FHE execution on the accuracy.\n", "\n", - "The torch/brevitas neural network is quantized during training and, for inference, it is converted \n", - "to FHE by Concrete ML using a dedicated function, `compile_brevitas_qat_model`.\n", - "\n", - "In this test we determine the accuracy and accumulator bit-widths for the various quantization settings\n", - "that are trained above." + "The torch neural network is converted to FHE by Concrete ML using a dedicated function, `compile_torch_model`." ] }, { @@ -461,22 +388,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 450/450 [00:01<00:00, 280.84it/s]\n", - "100%|██████████| 450/450 [00:01<00:00, 280.89it/s]\n", - "100%|██████████| 450/450 [00:01<00:00, 280.64it/s]\n", - "100%|██████████| 450/450 [00:01<00:00, 251.91it/s]\n", - "100%|██████████| 450/450 [00:01<00:00, 251.96it/s]" + "WARNING: high error rate, more details with --display-optimizer-choice\n", + "100%|██████████| 450/450 [00:01<00:00, 295.41it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Simulated FHE execution for 2 bit network: 1.61s, 280.09it/s\n", - "Simulated FHE execution for 3 bit network: 1.61s, 279.71it/s\n", - "Simulated FHE execution for 4 bit network: 1.61s, 280.00it/s\n", - "Simulated FHE execution for 5 bit network: 1.79s, 251.08it/s\n", - "Simulated FHE execution for 6 bit network: 1.79s, 251.20it/s\n" + "Simulated FHE execution for 6 bit network accuracy: 0.98%\n" ] }, { @@ -488,118 +408,19 @@ } ], "source": [ - "accs = []\n", - "accum_bits = []\n", - "sim_time = []\n", - "\n", - "\n", - "for idx in range(len(bit_range)):\n", - " q_module = compile_brevitas_qat_model(nets[idx], x_train)\n", - "\n", 
- " accum_bits.append(q_module.fhe_circuit.graph.maximum_integer_bit_width())\n", - "\n", - " start_time = time.time()\n", - " accs.append(\n", - " test_with_concrete(\n", - " q_module,\n", - " test_dataloader,\n", - " use_sim=True,\n", - " )\n", - " )\n", - " sim_time.append(time.time() - start_time)\n", - "\n", - "for idx, vl_time_bits in enumerate(sim_time):\n", - " print(\n", - " f\"Simulated FHE execution for {bit_range[idx]} bit network: {vl_time_bits:.2f}s, \"\n", - " f\"{len(test_dataloader) / vl_time_bits:.2f}it/s\"\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "da4666bf", - "metadata": {}, - "source": [ - "### Analysis of quantized results\n", + "n_bits = 6\n", "\n", - "We plot the accuracies obtained for various levels of quantization of weights and activations. \n", - "In addition, we plot the maximum accumulator bit width required to run inference of the network for\n", - "each weight and activation bit width. This is shown as the numbers next to the graph markers. \n", + "q_module = compile_torch_model(net, x_train, rounding_threshold_bits=6, p_error=0.1)\n", "\n", - "This accumulator bit width is determined by the compiler and is an important quantity in designing FHE-compatible neural networks." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "5b31947f", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig = plt.figure(figsize=(12, 8))\n", - "plt.rcParams[\"font.size\"] = 14\n", - "plt.plot(bit_range, accs, \"-x\")\n", - "for bits, acc, accum in zip(bit_range, accs, accum_bits):\n", - " plt.gca().annotate(str(accum), (bits - 0.1, acc + 0.01))\n", - "plt.ylabel(\"Accuracy on test set\")\n", - "plt.xlabel(\"Weight & activation quantization\")\n", - "plt.grid(True)\n", - "plt.title(\n", - " \"Accuracy for varying quantization bit width. Accumulator bit-width shown on graph markers\"\n", + "start_time = time.time()\n", + "accs = test_with_concrete(\n", + " q_module,\n", + " test_dataloader,\n", + " use_sim=True,\n", ")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "7acb5a3f", - "metadata": {}, - "source": [ - "### Test the CNN in FHE\n", - "\n", - "We identify 3 bit weights and activations as a good compromise for which the maximum accumulator size\n", - "is low but the accuracy is acceptable. We can now compile to FHE and execute on encrypted data." - ] - }, - { - "cell_type": "markdown", - "id": "a218aff7", - "metadata": {}, - "source": [ - "### 1. Compile to FHE" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "16fa5e3b", - "metadata": {}, - "outputs": [], - "source": [ - "bits_for_fhe = 3\n", - "idx_bits_fhe = bit_range.index(bits_for_fhe)\n", - "\n", - "accum_bits_required = accum_bits[idx_bits_fhe]\n", - "\n", - "q_module_fhe = None\n", - "\n", - "net = nets[idx_bits_fhe]\n", + "sim_time = time.time() - start_time\n", "\n", - "q_module_fhe = compile_brevitas_qat_model(\n", - " net,\n", - " x_train,\n", - ")" + "print(f\"Simulated FHE execution for {n_bits} bit network accuracy: {100 * accs:.2f}%\")" ] }, { @@ -607,12 +428,12 @@ "id": "2875e825", "metadata": {}, "source": [ - "### 2. 
Generate Keys" + "### Generate Keys" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "6e8b6471", "metadata": {}, "outputs": [ @@ -620,14 +441,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Keygen time: 137.80s\n" + "Keygen time: 3.98s\n" ] } ], "source": [ - "# Generate keys first, this may take some time (up to 30min)\n", + "# Generate keys first\n", "t = time.time()\n", - "q_module_fhe.fhe_circuit.keygen()\n", + "q_module.fhe_circuit.keygen()\n", "print(f\"Keygen time: {time.time()-t:.2f}s\")" ] }, @@ -649,14 +470,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:32<00:00, 32.33s/it]" + "100%|██████████| 100/100 [04:19<00:00, 2.59s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Time per inference in FHE: 32.34\n" + "Time per inference in FHE: 2.59 with 99.00% accuracy\n" ] }, { @@ -669,16 +490,23 @@ ], "source": [ - "# Run inference in FHE on a single encrypted example\n", - "mini_test_dataset = TensorDataset(torch.Tensor(x_test[[0], :]), torch.Tensor(y_test[[0]]))\n", + "# Run inference in FHE on the first 100 encrypted test examples\n", + "mini_test_dataset = TensorDataset(torch.Tensor(x_test[:100, :]), torch.Tensor(y_test[:100]))\n", "mini_test_dataloader = DataLoader(mini_test_dataset)\n", "\n", "t = time.time()\n", - "test_with_concrete(\n", - " q_module_fhe,\n", + "accuracy_test = test_with_concrete(\n", + " q_module,\n", " mini_test_dataloader,\n", " use_sim=False,\n", ")\n", - "print(f\"Time per inference in FHE: {(time.time() - t) / len(mini_test_dataset):.2f}\")" + "elapsed_time = time.time() - t\n", + "time_per_inference = elapsed_time / len(mini_test_dataset)\n", + "accuracy_percentage = 100 * accuracy_test\n", + "\n", + "print(\n", + " f\"Time per inference in FHE: {time_per_inference:.2f} \"\n", + " f\"with {accuracy_percentage:.2f}% accuracy\"\n", + ")" ] }, { @@ -688,13 +516,9 @@ "source": [ "### Conclusion\n", "\n", - "We see that quantization with **3** bit weight and activations is the best viable FHE 
configuration,\n", - "as the accumulator bit width for this configuration is between **7 and 8** bits (can vary due to the final \n", - "distribution of the weights). The accuracy in this setting, 92% is a few percentage points \n", - "under the maximum accuracy achievable with larger accumulator bit widths (97-98%). \n", + "In this example, a simple CNN model is trained with torch and reaches 98% accuracy in the clear. The model is then converted to FHE and evaluated on 100 test samples in FHE.\n", "\n", - "Compiling the higher bit-width networks is also possible, but in this example, to ensure FHE execution is fast\n", - "we used the lower bit-width quantization setting.\n" + "The model in FHE achieves **the same accuracy** as the original torch model, with an FHE execution time of about **2.6 seconds** per image." ] } ],