
Commit 8d67b99

Update shell_ml to use new ShellTensors. Bump version.

james-choncholas committed Feb 2, 2024
1 parent 9eb5b5e commit 8d67b99
Showing 9 changed files with 695 additions and 393 deletions.
310 changes: 225 additions & 85 deletions examples/label_dp_sgd.ipynb

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions shell_ml/activation.py
@@ -23,14 +23,18 @@ def relu(x):
 
 
 def relu_deriv(y, dy):
+    assert not isinstance(y, shell_tensor.ShellTensor64)
+    # Cannot operate on individual slots of a shell tensor.
+    # Formulate the problem as element-wise multiplication.
+    # t = np.dtype(dy.plaintext_dtype.as_numpy_dtype)
     if isinstance(dy, shell_tensor.ShellTensor64):
-        # Cannot operate on individual slots of a shell tensor.
-        # Formulate the problem as element-wise multiplication.
-        t = np.dtype(dy.plaintext_dtype.as_numpy_dtype)
-        mask = tf.where(y <= 0, t.type(0), t.type(1))
-        return dy * mask
+        dy_dtype = dy.plaintext_dtype
     else:
-        return dy * (y > 0)
+        dy_dtype = dy.dtype
+    mask = tf.where(
+        y <= 0, tf.constant(0, dtype=dy_dtype), tf.constant(1, dtype=dy_dtype)
+    )
+    return dy * mask
 
 
 def sigmoid(x):
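The new relu_deriv never touches individual slots: it builds a 0/1 mask with tf.where and multiplies it into the upstream gradient, which is the form of the operation that still makes sense when dy is a packed ShellTensor64. As a rough illustration, here is a plaintext-only sketch of the same masking trick in plain TensorFlow (no shell_tensor involved; the function name is illustrative):

```python
import tensorflow as tf


def relu_deriv_plain(y, dy):
    # ReLU derivative written as an element-wise product with a 0/1 mask,
    # instead of indexing or branching on individual elements.
    mask = tf.where(y <= 0, tf.constant(0.0, dy.dtype), tf.constant(1.0, dy.dtype))
    return dy * mask


y = tf.constant([-1.0, 0.0, 2.0, 3.0])
dy = tf.constant([10.0, 10.0, 10.0, 10.0])
print(relu_deriv_plain(y, dy).numpy())  # [ 0.  0. 10. 10.]
```

The encrypted path computes the same product, with the mask kept in plaintext while dy may be a ciphertext.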
40 changes: 26 additions & 14 deletions shell_ml/dense.py
@@ -28,6 +28,8 @@ def __init__(
         kernel_initializer="glorot_uniform",
         bias_initializer="zeros",
         skip_normalization=True,
+        fxp_fractional_bits=1,
+        weight_dtype=tf.int64,
     ):
         self.units = int(units)
         self.activation = activation
@@ -37,17 +39,23 @@ def __init__(
         self.kernel_initializer = initializers.get(kernel_initializer)
         self.bias_initializer = initializers.get(bias_initializer)
         self.skip_normalization = skip_normalization
+        self.fxp_fractional_bits = fxp_fractional_bits
+        self.weight_dtype = weight_dtype
 
         self.built = False
         self.weights = []
 
     def build(self, input_shape):
         self.units_in = int(input_shape[1])
         self.kernel = self.kernel_initializer([self.units_in, self.units])
+
+        # Convert kernel to fixed point.
+        self.kernel = tf.cast(self.kernel, self.weight_dtype)
         self.weights.append(self.kernel)
 
         if self.use_bias:
             self.bias = self.bias_initializer([self.units])
+            self.bias = tf.cast(self.bias, self.weight_dtype)
             self.weights.append(self.bias)
         else:
             self.bias = None
@@ -59,9 +67,6 @@ def __call__(self, inputs):
 
         self._layer_input = inputs
 
-        print("inputs dtype: ", inputs.dtype)
-        print("kernel dtype: ", self.weights[0].dtype)
-
         if self.use_bias:
             outputs = tf.matmul(inputs, self.weights[0]) + self.weights[1]
         else:
@@ -76,7 +81,7 @@ def __call__(self, inputs):
 
         return outputs
 
-    def backward(self, dy, is_first_layer=False, temp_key=None):
+    def backward(self, dy, rotation_key, is_first_layer=False):
         """dense backward"""
         x = self._layer_input
         z = self._layer_intermediate
@@ -88,23 +93,30 @@ def backward(self, dy, is_first_layer=False, temp_key=None):
         if self.activation_deriv is not None:
             dy = self.activation_deriv(z, dy)
 
+        if isinstance(dy, shell_tensor.ShellTensor64):
+            # It is a good idea to reduce the multiplication count before
+            # the multiplication with the kernel. Multiplying by the kernel
+            # requires a reduce_sum operation which makes it easy to exceed
+            # the plaintext modulus by the fixed point fractional bits, if
+            # the multiplication count is too high.
+            dy = dy.get_at_multiplication_count(0)
+
         if is_first_layer:
             d_x = None  # no gradient needed for first layer
         else:
-            d_x = shell_tensor.matmul(dy, tf.transpose(kernel))
-        # TODO(jchoncholas): this is stubbed in for now. Since we dont have a
-        # "reduce_sum" operation, e.g. compute the sum of all elements in a
-        # polynomial, we cheat and decrypt-compute-encrypt. This requires
-        # passing the key etc. to the op but once slot rotation is implemented
-        # this wont be necessary.
-        d_weights = shell_tensor.matmul(tf.transpose(x), dy, temp_key)
+            kernel_t = tf.transpose(kernel)
+            d_x = shell_tensor.matmul(dy, kernel_t)
+
+        # Perform the fixed point multiplication.
+        d_weights = shell_tensor.matmul(tf.transpose(x), dy, rotation_key)
 
         if not self.skip_normalization:
-            d_weights = d_weights / batch_size
+            assert False, "Normalization not implemented yet."
+            # d_weights = d_weights / batch_size
         grad_weights.append(d_weights)
 
         if self.use_bias:
-            assert False, "Bias Not implemented yet"
+            assert False, "Bias not implemented yet"
             # TODO(jchoncholas): reduce_sum is very expensive and requires slot rotation.
             # Not implemented yet. A better way than the reduce sum is to set batch size to 1 less
             # and use that last slot as the bias with input 1.
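For orientation, the gradients assembled in backward are the usual dense-layer ones; the new part is only that dy may be a packed ciphertext, so both products go through shell_tensor.matmul and the x^T @ dy reduction needs the rotation key. Below is a plaintext sketch of the same shapes and math, written as a hypothetical standalone function rather than the ShellDense API:

```python
import tensorflow as tf


def dense_backward_plain(x, kernel, dy, is_first_layer=False):
    # x: layer input (batch, units_in); kernel: (units_in, units);
    # dy: upstream gradient (batch, units).
    # Gradient w.r.t. the layer input, skipped for the first layer.
    d_x = None if is_first_layer else tf.matmul(dy, tf.transpose(kernel))
    # Gradient w.r.t. the kernel; the sum over the batch dimension is what
    # requires slot rotations when dy is a packed ciphertext.
    d_weights = tf.matmul(tf.transpose(x), dy)
    return d_weights, d_x


x = tf.random.normal((8, 784))
kernel = tf.random.normal((784, 64))
dy = tf.random.normal((8, 64))
d_w, d_x = dense_backward_plain(x, kernel, dy)
print(d_w.shape, d_x.shape)  # (784, 64) (8, 784)
```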
8 changes: 4 additions & 4 deletions shell_ml/test/BUILD
@@ -1,19 +1,19 @@
 load("@pip//:requirements.bzl", "requirement")
 
 py_test(
-    name = "mnist_plaintext_post_scale_test",
+    name = "mnist_post_scale_test",
     size = "enormous",
-    srcs = ["mnist_plaintext_post_scale_test.py"],
+    srcs = ["mnist_post_scale_test.py"],
     deps = [
         "//shell_ml",
         requirement("tensorflow-cpu"),
     ],
 )
 
 py_test(
-    name = "mnist_shell_backprop_test",
+    name = "mnist_backprop_test",
     size = "enormous",
-    srcs = ["mnist_shell_backprop_test.py"],
+    srcs = ["mnist_backprop_test.py"],
     deps = [
         "//shell_ml",
         requirement("tensorflow-cpu"),
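With these renames, the suites would presumably be invoked as bazel test //shell_ml/test:mnist_backprop_test and bazel test //shell_ml/test:mnist_post_scale_test, matching the new target names.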
196 changes: 196 additions & 0 deletions shell_ml/test/mnist_backprop_test.py
@@ -0,0 +1,196 @@
#!/usr/bin/python
#
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import shell_tensor
import shell_ml

plaintext_dtype = tf.float32
fxp_num_bits = 5 # number of fractional bits.


# Shell setup.
log_slots = 11
slots = 2**log_slots

# Num plaintext bits: 27, noise bits: 65, num rns moduli: 2
context = shell_tensor.create_context64(
    log_n=11,
    main_moduli=[140737488486401, 140737488498689],
    aux_moduli=[],
    plaintext_modulus=134246401,
    noise_variance=8,
    seed="",
)
key = shell_tensor.create_key64(context)
rotation_key = shell_tensor.create_rotation_key64(context, key)

# Training setup.
epochs = 1
batch_size = slots
stop_after_n_batches = 1

# Prepare the dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)


# Create the layers
hidden_layer = shell_ml.ShellDense(
    64,
    activation=shell_ml.relu,
    activation_deriv=shell_ml.relu_deriv,
    fxp_fractional_bits=fxp_num_bits,
    weight_dtype=plaintext_dtype,
)
output_layer = shell_ml.ShellDense(
    10,
    activation=shell_ml.sigmoid,
    # activation_deriv=shell_ml.sigmoid_deriv,
    fxp_fractional_bits=fxp_num_bits,
    weight_dtype=plaintext_dtype,
)

# Call the layers once to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = shell_ml.CategoricalCrossentropy()
optimizer = shell_ml.Adam()
optimizer.compile([hidden_layer.weights, output_layer.weights])


def train_step(x, y):
    # Forward pass always in plaintext.
    y_1 = hidden_layer(x)
    y_pred = output_layer(y_1)
    # loss = loss_fn(y, y_pred)  # this is expensive and not needed for training

    # Backward pass.
    dJ_dy_pred = loss_fn.grad(y, y_pred)

    (dJ_dw1, dJ_dx1) = output_layer.backward(
        dJ_dy_pred, rotation_key, is_first_layer=False
    )

    (dJ_dw0, dJ_dx0_unused) = hidden_layer.backward(
        dJ_dx1, rotation_key, is_first_layer=True
    )

    # Only return the weight gradients at [0], not the bias gradients at [1].
    return dJ_dw1[0], dJ_dw0[0]


class TestMNISTBackprop(tf.test.TestCase):
    def test_mnist_plaintext_backprop(self):
        (x_batch, y_batch) = next(iter(train_dataset))

        start_time = time.time()

        # Plaintext backprop splitting the batch in half vertically.
        top_x_batch, bottom_x_batch = tf.split(x_batch, num_or_size_splits=2, axis=0)
        top_y_batch, bottom_y_batch = tf.split(y_batch, num_or_size_splits=2, axis=0)
        top_output_layer_grad, top_hidden_layer_grad = train_step(
            top_x_batch, top_y_batch
        )
        bottom_output_layer_grad, bottom_hidden_layer_grad = train_step(
            bottom_x_batch, bottom_y_batch
        )

        # Stack the top and bottom gradients back together along a new
        # outer dimension.
        output_layer_grad = tf.concat(
            [
                tf.expand_dims(top_output_layer_grad, axis=0),
                tf.expand_dims(bottom_output_layer_grad, axis=0),
            ],
            axis=0,
        )
        hidden_layer_grad = tf.concat(
            [
                tf.expand_dims(top_hidden_layer_grad, axis=0),
                tf.expand_dims(bottom_hidden_layer_grad, axis=0),
            ],
            axis=0,
        )

        # Encrypt y using fixed point representation.
        enc_y_batch = shell_tensor.to_shell_tensor(
            context, y_batch, fxp_fractional_bits=fxp_num_bits
        ).get_encrypted(key)

        # Backprop.
        enc_output_layer_grad, enc_hidden_layer_grad = train_step(x_batch, enc_y_batch)

        # Decrypt the gradients.
        repeated_output_layer_grad = enc_output_layer_grad.get_decrypted(key)
        repeated_hidden_layer_grad = enc_hidden_layer_grad.get_decrypted(key)

        print(f"\tFinished Stamp: {time.time() - start_time}")
        print(f"\tOutput Layer Noise: {enc_output_layer_grad.noise_bits}")
        print(f"\tHidden Layer Noise: {enc_hidden_layer_grad.noise_bits}")
        print(
            f"\tOutput Layer fxp bits: {enc_output_layer_grad.num_fxp_fractional_bits}"
        )
        print(
            f"\tHidden Layer fxp bits: {enc_hidden_layer_grad.num_fxp_fractional_bits}"
        )

        shell_output_layer_grad = tf.concat(
            [
                tf.expand_dims(repeated_output_layer_grad[0, ...], 0),
                tf.expand_dims(repeated_output_layer_grad[slots // 2, ...], 0),
            ],
            axis=0,
        )
        shell_hidden_layer_grad = tf.concat(
            [
                tf.expand_dims(repeated_hidden_layer_grad[0, ...], 0),
                tf.expand_dims(repeated_hidden_layer_grad[slots // 2, ...], 0),
            ],
            axis=0,
        )

        # Compare the gradients.
        self.assertAllClose(
            output_layer_grad,
            shell_output_layer_grad,
            atol=slots * 2.0 ** (-fxp_num_bits),
        )

        self.assertAllClose(
            hidden_layer_grad,
            shell_hidden_layer_grad,
            atol=slots * 2.0 ** (-fxp_num_bits - 2),
        )

        print(f"Total plaintext training time: {time.time() - start_time} seconds")


if __name__ == "__main__":
    unittest.main()
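A note on the tolerances in the asserts above: with fxp_num_bits = 5 the fixed-point encoding quantizes values to multiples of 2**-5, and the d_weights reduction sums on the order of slots such terms, which is presumably why the allowed absolute error scales as slots * 2**-fxp_num_bits. Checking that arithmetic in plain Python:

```python
log_slots = 11
slots = 2**log_slots          # 2048 values packed per ciphertext
fxp_num_bits = 5              # fractional bits of the fixed-point encoding

quantum = 2.0**-fxp_num_bits  # smallest representable step: 0.03125

# Tolerances used by the assertAllClose calls above.
output_atol = slots * 2.0 ** (-fxp_num_bits)      # 2048 * 0.03125   = 64.0
hidden_atol = slots * 2.0 ** (-fxp_num_bits - 2)  # 2048 * 0.0078125 = 16.0

print(quantum, output_atol, hidden_atol)  # 0.03125 64.0 16.0
```

The plaintext reference is computed on the two half-batches and the decrypted result is read at slots 0 and slots // 2, which suggests the packed layout leaves one copy of each half-batch gradient in each half of the slot dimension.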
