Support key caching.
james-choncholas committed Oct 2, 2024
1 parent aeb9f9b commit 321a163
Showing 3 changed files with 165 additions and 145 deletions.
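The change threads an optional cache_path argument through the three key-generation helpers in shell_key.py. Each now follows the same read-or-generate round trip: if a serialized key file already exists at cache_path, it is deserialized and reused; otherwise the keys are generated and, when a path was given, serialized to disk for the next run. A minimal standalone sketch of that pattern (the read_or_generate helper and its arguments are illustrative, not part of tf-shell):

import tensorflow as tf

def read_or_generate(cache_path, generate_fn, dtype=tf.variant):
    # Reuse a previously serialized tensor if the cache file exists.
    if cache_path and tf.io.gfile.exists(cache_path):
        return tf.io.parse_tensor(tf.io.read_file(cache_path), out_type=dtype)
    # Otherwise generate the value and, when a path is given, persist it.
    value = generate_fn()
    if cache_path:
        tf.io.write_file(cache_path, tf.io.serialize_tensor(value))
    return value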
253 changes: 108 additions & 145 deletions tf_shell/python/shell_key.py
@@ -32,57 +32,75 @@ def _get_key_at_level(self, level):
        return self._raw_keys_at_level[level]


-def create_key64(context):
+def create_key64(context, cache_path=""):
    if not isinstance(context, ShellContext64):
        raise ValueError("context must be a ShellContext64")

-    try:
-        num_keys = context.level.numpy()
-    except:
-        num_keys = context.level
-
-    if isinstance(num_keys, tf.Tensor):
-        keys = tf.TensorArray(tf.variant, size=context.level, clear_after_read=False)
-
-        # Generate and store the first key in the last index.
-        keys = keys.write(
-            num_keys - 1, shell_ops.key_gen64(context._get_context_at_level(num_keys))
-        )
-
-        # Mod reduce to compute the remaining keys.
-        keys, _ = tf.while_loop(
-            lambda ks, l: l > 1,
-            lambda ks, l: (
-                ks.write(
-                    l - 2,
-                    shell_ops.modulus_reduce_key64(
-                        context._get_context_at_level(l), ks.read(l - 1)
-                    ),
-                ),
-                l - 1,
-            ),
-            loop_vars=[keys, num_keys],
-            shape_invariants=[
-                tf.TensorSpec(None, dtype=tf.variant),
-                tf.TensorSpec([], dtype=tf.int32),
-            ],
-            parallel_iterations=1,
-        )
-
-        return ShellKey64(_raw_keys_at_level=keys.gather(tf.range(0, num_keys)))
-    else:
-        # Compared to the approach above, this code embeds the looping logic for
-        # key generation in the tf graph. This is slightly faster, but should
-        # also allow TensorFlow to optimize the graph better, e.g. pruning
-        # unused keys from the graph. Note, the fact that these are stored as
-        # a tensor vs. python list may inhibit this.
-        keys = [shell_ops.key_gen64(context._get_context_at_level(num_keys))]
-        for i in range(num_keys, 1, -1):
-            keys.insert(
-                0, shell_ops.modulus_reduce_key64(context._raw_contexts[i - 1], keys[0])
-            )
-        return ShellKey64(_raw_keys_at_level=keys)
+    @tf.py_function(Tout=tf.variant)
+    def read_or_generate_keys(cache_path, context):
+        exists = tf.io.gfile.exists(cache_path.numpy())
+        if exists:
+            cached_keys = tf.io.read_file(cache_path)
+            cached_keys = tf.io.parse_tensor(cached_keys, out_type=tf.variant)
+            return cached_keys
+
+        else:
+            # num_keys = context.level
+            # keys = tf.TensorArray(tf.variant, size=context.level, clear_after_read=False)
+
+            # # Generate and store the first key in the last index.
+            # keys = keys.write(
+            #     num_keys - 1, shell_ops.key_gen64(context._get_context_at_level(num_keys))
+            # )
+
+            # # Mod reduce to compute the remaining keys.
+            # keys, _ = tf.while_loop(
+            #     lambda ks, l: l > 1,
+            #     lambda ks, l: (
+            #         ks.write(
+            #             l - 2,
+            #             shell_ops.modulus_reduce_key64(
+            #                 context._get_context_at_level(l), ks.read(l - 1)
+            #             ),
+            #         ),
+            #         l - 1,
+            #     ),
+            #     loop_vars=[keys, num_keys],
+            #     shape_invariants=[
+            #         tf.TensorSpec(None, dtype=tf.variant),
+            #         tf.TensorSpec([], dtype=tf.int32),
+            #     ],
+            #     parallel_iterations=1,
+            # )
+
+            # gathered_keys = keys.gather(tf.range(0, num_keys))
+
+            # if cache_path is not None:
+            #     print(f"Writing keys to {cache_path}")
+            #     tf.io.write_file(cache_path, tf.io.serialize_tensor(gathered_keys))
+
+            # return gathered_keys
+
+            # The approach above is graph-friendly but this is unnecessary when
+            # wrapped with a tf.py_function. The code below is more concise and
+            # slightly faster.
+            num_keys = context.level.numpy()
+            keys = [shell_ops.key_gen64(context._get_context_at_level(num_keys))]
+            for i in range(num_keys, 1, -1):
+                keys.insert(
+                    0,
+                    shell_ops.modulus_reduce_key64(
+                        context._raw_contexts[i - 1], keys[0]
+                    ),
+                )
+
+            if cache_path.numpy() != b"":
+                tf.io.write_file(cache_path, tf.io.serialize_tensor(keys))
+
+            return tf.convert_to_tensor(keys, dtype=tf.variant)
+
+    raw_keys = read_or_generate_keys(cache_path, context)
+    return ShellKey64(_raw_keys_at_level=raw_keys)
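The comment inside the new read_or_generate_keys captures the design choice: because the body runs under tf.py_function, it executes eagerly even when create_key64 is traced into a graph, so plain Python loops and .numpy() calls can replace the earlier tf.while_loop/TensorArray machinery. A toy sketch of that trade-off, with made-up names (the cost is that TensorFlow cannot optimize across the py_function boundary):

import tensorflow as tf

def eager_helper(n):
    # Runs eagerly inside tf.py_function, so .numpy() and Python loops work.
    count = int(n.numpy())
    return tf.constant([float(i) for i in range(count)], dtype=tf.float32)

@tf.function
def traced(n):
    # Embeds the eager helper as a single op in an otherwise traced graph.
    return tf.py_function(eager_helper, inp=[n], Tout=tf.float32)

traced(tf.constant(4))  # [0., 1., 2., 3.]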


class ShellRotationKey64(tf.experimental.ExtensionType):
@@ -98,7 +116,7 @@ def _get_key_at_level(self, level):
        return self._raw_keys_at_level[level]


-def create_rotation_key64(context, key):
+def create_rotation_key64(context, key, cache_path=""):
"""Create rotation keys for any multiplicative depth of the given context.
Rotation key contains keys to perform an arbitrary number of slot rotations.
Since rotation key generation is expensive, the caller can choose to skip
@@ -110,58 +128,31 @@ def create_rotation_key64(context, key):
    if not isinstance(key, ShellKey64):
        raise ValueError("key must be a ShellKey64.")

-    try:
-        num_keys = context.level.numpy()
-    except:
-        num_keys = context.level
-
-    if isinstance(num_keys, tf.Tensor):
-        rot_keys = tf.TensorArray(
-            tf.variant,
-            size=num_keys,
-            clear_after_read=False,
-            infer_shape=False,
-            element_shape=(),
-        )
-
-        # Generate rotation keys starting from the highest level.
-        rot_keys, _ = tf.while_loop(
-            lambda ks, l: l > 0,
-            lambda ks, l: (
-                ks.write(
-                    l - 1,
-                    shell_ops.rotation_key_gen64(
-                        context._get_context_at_level(l), key._get_key_at_level(l)
-                    ),
-                ),
-                l - 1,
-            ),
-            loop_vars=[rot_keys, num_keys],
-            shape_invariants=[
-                tf.TensorSpec(None, dtype=tf.variant),
-                tf.TensorSpec([], dtype=tf.int32),
-            ],
-            parallel_iterations=1,
-        )
-
-        return ShellRotationKey64(
-            _raw_keys_at_level=rot_keys.gather(tf.range(0, num_keys))
-        )
-    else:
-        # Compared to the approach above, this code embeds the looping logic for
-        # key generation in the tf graph. This is slightly faster, but should
-        # also allow TensorFlow to optimize the graph better, e.g. pruning
-        # unused keys from the graph. Note, the fact that these are stored as
-        # a tensor vs. python list may inhibit this.
-        rot_keys = []
-        for i in range(num_keys, 0, -1):
-            rot_keys.insert(
-                0,
-                shell_ops.rotation_key_gen64(
-                    context._raw_contexts[i - 1], key._raw_keys_at_level[i - 1]
-                ),
-            )
-        return ShellRotationKey64(_raw_keys_at_level=rot_keys)
+    @tf.py_function(Tout=tf.variant)
+    def read_or_generate_keys(context, key, cache_path):
+        exists = tf.io.gfile.exists(cache_path.numpy())
+        if exists:
+            cached_keys = tf.io.read_file(cache_path)
+            cached_keys = tf.io.parse_tensor(cached_keys, out_type=tf.variant)
+            return cached_keys
+
+        else:
+            rot_keys = []
+            for i in range(context.level.numpy(), 0, -1):
+                rot_keys.insert(
+                    0,
+                    shell_ops.rotation_key_gen64(
+                        context._raw_contexts[i - 1], key._raw_keys_at_level[i - 1]
+                    ),
+                )
+
+            if cache_path.numpy() != b"":
+                tf.io.write_file(cache_path, tf.io.serialize_tensor(rot_keys))
+
+            return tf.convert_to_tensor(rot_keys, dtype=tf.variant)
+
+    raw_keys = read_or_generate_keys(context, key, cache_path)
+    return ShellRotationKey64(_raw_keys_at_level=raw_keys)
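As the docstring above notes, rotation key generation is expensive, which is where the new cache_path argument pays off: the first run bears the generation cost and later runs simply deserialize the result. A hypothetical call site (paths are illustrative, and this assumes create_rotation_key64 is re-exported under the tf_shell namespace the same way create_key64 is in the test below):

import tf_shell

# Context parameters taken from the new key_test.py below.
context = tf_shell.create_context64(
    log_n=11,
    main_moduli=[288230376151748609, 18014398509506561],
    plaintext_modulus=281474976768001,
    scaling_factor=1052673,
)

# The first call generates and serializes; subsequent calls read from disk.
key = tf_shell.create_key64(context, "/tmp/cached_secret_key")
rot_key = tf_shell.create_rotation_key64(context, key, "/tmp/cached_rot_key")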


class ShellFastRotationKey64(tf.experimental.ExtensionType):
@@ -177,7 +168,7 @@ def _get_key_at_level(self, level):
        return self._raw_keys_at_level[level]


-def create_fast_rotation_key64(context, key):
+def create_fast_rotation_key64(context, key, cache_path=""):
    """Create fast rotation keys for any multiplicative depth of the given context.
    Rotation key contains keys to *decrypt* a previously "fast" rotated ciphertext.
    These keys are much faster to generate than regular rotation keys, and
@@ -190,56 +181,28 @@ def create_fast_rotation_key64(context, key):
    if not isinstance(key, ShellKey64):
        raise ValueError("key must be a ShellKey64.")

-    try:
-        num_keys = context.level.numpy()
-    except:
-        num_keys = context.level
-
-    if isinstance(num_keys, tf.Tensor):
-        num_keys = context.level
-        rot_keys = tf.TensorArray(
-            tf.variant,
-            size=num_keys,
-            clear_after_read=False,
-            infer_shape=False,
-            element_shape=(),
-        )
-
-        # Generate rotation keys starting from the highest level.
-        rot_keys, _ = tf.while_loop(
-            lambda ks, l: l > 0,
-            lambda ks, l: (
-                ks.write(
-                    l - 1,
-                    shell_ops.fast_rotation_key_gen64(
-                        context._get_context_at_level(l), key._get_key_at_level(l)
-                    ),
-                ),
-                l - 1,
-            ),
-            loop_vars=[rot_keys, num_keys],
-            shape_invariants=[
-                tf.TensorSpec(None, dtype=tf.variant),
-                tf.TensorSpec([], dtype=tf.int32),
-            ],
-            parallel_iterations=1,
-        )
-
-        return ShellFastRotationKey64(
-            _raw_keys_at_level=rot_keys.gather(tf.range(0, num_keys))
-        )
-    else:
-        # Compared to the approach above, this code embeds the looping logic for
-        # key generation in the tf graph. This is slightly faster, but should
-        # also allow TensorFlow to optimize the graph better, e.g. pruning
-        # unused keys from the graph. Note, the fact that these are stored as
-        # a tensor vs. python list may inhibit this.
-        rot_keys = []
-        for i in range(num_keys, 0, -1):
-            rot_keys.insert(
-                0,
-                shell_ops.fast_rotation_key_gen64(
-                    context._raw_contexts[i - 1], key._raw_keys_at_level[i - 1]
-                ),
-            )
-        return ShellFastRotationKey64(_raw_keys_at_level=rot_keys)
+    @tf.py_function(Tout=tf.variant)
+    def read_or_generate_keys(context, key, cache_path):
+        exists = tf.io.gfile.exists(cache_path.numpy())
+        if exists:
+            cached_keys = tf.io.read_file(cache_path)
+            cached_keys = tf.io.parse_tensor(cached_keys, out_type=tf.variant)
+            return cached_keys
+
+        else:
+            rot_keys = []
+            for i in range(context.level, 0, -1):
+                rot_keys.insert(
+                    0,
+                    shell_ops.fast_rotation_key_gen64(
+                        context._raw_contexts[i - 1], key._raw_keys_at_level[i - 1]
+                    ),
+                )
+
+            if cache_path.numpy() != b"":
+                tf.io.write_file(cache_path, tf.io.serialize_tensor(rot_keys))
+
+            return tf.convert_to_tensor(rot_keys, dtype=tf.variant)
+
+    raw_keys = read_or_generate_keys(context, key, cache_path)
+    return ShellFastRotationKey64(_raw_keys_at_level=raw_keys)
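create_fast_rotation_key64 gets the same treatment. Per its docstring these keys decrypt previously "fast" rotated ciphertexts and are much cheaper to generate, but caching them alongside the others keeps all key material stable across runs. Continuing the sketch above (again assuming the function is exported under tf_shell; the path is illustrative):

fast_rot_key = tf_shell.create_fast_rotation_key64(
    context, key, "/tmp/cached_fast_rot_key"
)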
14 changes: 14 additions & 0 deletions tf_shell/test/BUILD
@@ -123,6 +123,20 @@ py_test(
    ],
)

+py_test(
+    name = "key_test",
+    size = "small",
+    srcs = [
+        "key_test.py",
+        "test_utils.py",
+    ],
+    imports = ["./"],
+    deps = [
+        "//tf_shell:tf_shell_lib",
+        requirement("tensorflow-cpu"),
+    ],
+)
+
py_test(
    name = "segment_test",
    size = "enormous",
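Assuming the repository's usual Bazel workflow, the new target can be run directly with bazel test //tf_shell/test:key_test.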
43 changes: 43 additions & 0 deletions tf_shell/test/key_test.py
@@ -0,0 +1,43 @@
+#!/usr/bin/python
+#
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tensorflow as tf
+import tf_shell
+import tempfile
+
+
+class TestShellContext(tf.test.TestCase):
+    def test_key_save(self):
+        # Num plaintext bits: 48, noise bits: 65
+        # Max plaintext value: 127, est error: 3.840%
+        context = tf_shell.create_context64(
+            log_n=11,
+            main_moduli=[288230376151748609, 18014398509506561],
+            plaintext_modulus=281474976768001,
+            scaling_factor=1052673,
+        )
+        key_path = tempfile.mkdtemp()  # Every trace gets a new key.
+        key = tf_shell.create_key64(context, key_path + "cached_test_key")
+
+        a = tf.ones([2**11, 2, 3], dtype=tf.float32) * 10
+        ea = tf_shell.to_encrypted(a, key, context)
+
+        # Try decrypting with the cached key
+        cached_key = tf_shell.create_key64(context, key_path + "cached_test_key")
+        self.assertAllClose(a, tf_shell.to_tensorflow(ea, cached_key))
+
+
+if __name__ == "__main__":
+    tf.test.main()
