Skip to content

Commit

Permalink
Fix graph optimizers.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Sep 29, 2024
1 parent c60c6b6 commit 0cf8ad0
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 213 deletions.
2 changes: 2 additions & 0 deletions examples/label_dp_sgd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@
" metrics=[tf.keras.metrics.CategoricalAccuracy()],\n",
")\n",
"\n",
"train_datset = m.set_dataset_batching(train_dataset)\n",
"\n",
"# m.build([batch_size, 784]) # do not build if using autoparams\n",
"# m(train_dataset)\n",
"# m.summary()\n"
Expand Down
1 change: 0 additions & 1 deletion tf_shell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

from tf_shell.python.shell_context import ShellContext64
from tf_shell.python.shell_context import create_context64
from tf_shell.python.shell_context import mod_reduce_context64
from tf_shell.python.shell_context import create_autocontext64

from tf_shell.python.shell_key import ShellKey64
Expand Down
2 changes: 1 addition & 1 deletion tf_shell/cc/optimizers/pt_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace grappler {

namespace {

constexpr bool const debug = true;
constexpr bool const debug = false;

bool IsReplaceableOp(NodeDef const& node) {
return IsAddPtPt(node) || IsSubPtPt(node) || IsMulPtPt(node) || IsNegPt(node);
Expand Down
2 changes: 1 addition & 1 deletion tf_shell/cc/optimizers/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ constexpr char kReduceSum[] = "ReduceSum64";

constexpr char kUnsortedCtSegmentSum[] = "UnsortedCtSegmentSum";

// TensorFlow names
constexpr char kExpandDimsVariant[] = "ExpandDimsVariant";
constexpr char kBroadcastToShape[] = "BroadcastToShape"; // TODO check name
constexpr char kReshape[] = "Reshape"; // TODO check name

constexpr char kConstOpName[] = "Const";

bool IsShellContext(NodeDef const& node);
Expand Down
87 changes: 57 additions & 30 deletions tf_shell/python/shell_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class ShellContext64(tf.experimental.ExtensionType):
_raw_context: tf.Tensor
_raw_contexts: tf.Tensor
is_autocontext: bool
log_n: tf.Tensor
num_slots: tf.Tensor
Expand All @@ -34,7 +34,7 @@ class ShellContext64(tf.experimental.ExtensionType):

def __init__(
self,
_raw_context,
_raw_contexts,
is_autocontext,
log_n,
main_moduli,
Expand All @@ -44,7 +44,7 @@ def __init__(
scaling_factor,
seed,
):
self._raw_context = _raw_context
self._raw_contexts = _raw_contexts
self.is_autocontext = is_autocontext
self.log_n = tf.convert_to_tensor(log_n, dtype=tf.uint64)
self.num_slots = 2 ** tf.cast(log_n, dtype=tf.int64)
Expand All @@ -63,9 +63,18 @@ def __init__(
self.scaling_factor = scaling_factor
self.seed = seed

def _get_context_at_level(self, level):
level -= 1 # Context tensor start at level 1.
tf.Assert(level >= 0, [f"level must be >= 0. Got {level}"])
tf.Assert(
level < tf.shape(self._raw_contexts)[0],
[f"level must be < {tf.shape(self._raw_contexts)[0]}. Got {level}"],
)
return self._raw_contexts[level]

def _get_generic_context_spec(self):
return ShellContext64.Spec(
_raw_context=tf.TensorSpec([], dtype=tf.variant),
_raw_contexts=tf.TensorSpec([], dtype=tf.variant),
is_autocontext=self.is_autocontext,
log_n=tf.TensorSpec([], dtype=tf.uint64),
num_slots=tf.TensorSpec([], dtype=tf.int64),
Expand All @@ -80,27 +89,6 @@ def _get_generic_context_spec(self):
)


def mod_reduce_context64(context):
if not isinstance(context, ShellContext64):
raise ValueError("context must be a ShellContext64.")

smaller_context = shell_ops.modulus_reduce_context64(context._raw_context)

mod_reduced = ShellContext64(
_raw_context=smaller_context,
is_autocontext=context.is_autocontext,
log_n=context.log_n,
main_moduli=context.main_moduli[:-1],
aux_moduli=context.aux_moduli,
plaintext_modulus=context.plaintext_modulus,
noise_variance=context.noise_variance,
scaling_factor=context.scaling_factor,
seed=context.seed,
)

return mod_reduced


def create_context64(
log_n,
main_moduli,
Expand All @@ -115,17 +103,37 @@ def create_context64(
elif len(seed) < 64 and seed != "":
seed = seed.ljust(64)

shell_context, _, _, _, _ = shell_ops.context_import64(
context_sz = tf.shape(main_moduli)[0]
raw_contexts = tf.TensorArray(tf.variant, size=context_sz, clear_after_read=False)

# Generate and store the first context in the last index.
first_context, _, _, _, _ = shell_ops.context_import64(
log_n=log_n,
main_moduli=main_moduli,
aux_moduli=aux_moduli,
plaintext_modulus=plaintext_modulus,
noise_variance=noise_variance,
seed=seed,
)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
)

return ShellContext64(
_raw_context=shell_context,
_raw_contexts=raw_contexts.gather(tf.range(0, context_sz)),
is_autocontext=False,
log_n=log_n,
main_moduli=main_moduli,
Expand All @@ -146,17 +154,36 @@ def create_autocontext64(
):
if len(seed) > 64:
raise ValueError("Seed must be at most 64 characters long.")
seed = seed.ljust(64)
elif len(seed) < 64 and seed != "":
seed = seed.ljust(64)

shell_context, new_log_n, new_qs, new_ps, new_t = shell_ops.auto_shell_context64(
first_context, new_log_n, new_qs, new_ps, new_t = shell_ops.auto_shell_context64(
log2_cleartext_sz=log2_cleartext_sz,
scaling_factor=scaling_factor,
log2_noise_offset=noise_offset_log2,
noise_variance=noise_variance,
)
context_sz = tf.shape(new_qs)[0]
raw_contexts = tf.TensorArray(tf.variant, size=context_sz, clear_after_read=False)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
)

return ShellContext64(
_raw_context=shell_context,
_raw_contexts=raw_contexts.gather(tf.range(0, context_sz)),
is_autocontext=True,
log_n=new_log_n,
main_moduli=new_qs,
Expand Down
98 changes: 34 additions & 64 deletions tf_shell/python/shell_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
import tf_shell.python.shell_ops as shell_ops
from tf_shell.python.shell_context import ShellContext64
from tf_shell.python.shell_context import mod_reduce_context64
import tensorflow as tf
import typing

Expand All @@ -40,35 +39,28 @@ def create_key64(context):
keys = tf.TensorArray(tf.variant, size=context.level, clear_after_read=False)

# Generate and store the first key in the last index.
keys = keys.write(context.level - 1, shell_ops.key_gen64(context._raw_context))
keys = keys.write(
num_keys - 1, shell_ops.key_gen64(context._get_context_at_level(num_keys))
)

# Mod reduce to compute the remaining keys.
keys, context = tf.while_loop(
lambda ks, c: c.level > 2,
lambda ks, c: (
keys, _ = tf.while_loop(
lambda ks, l: l > 1,
lambda ks, l: (
ks.write(
c.level - 2,
shell_ops.modulus_reduce_key64(c._raw_context, ks.read(c.level - 1)),
l - 2,
shell_ops.modulus_reduce_key64(
context._get_context_at_level(l), ks.read(l - 1)
),
),
mod_reduce_context64(c),
l - 1,
),
loop_vars=[keys, context],
loop_vars=[keys, num_keys],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
context._get_generic_context_spec(),
tf.TensorSpec([], dtype=tf.int32),
],
)

# Store the first key for level 1.
keys = tf.cond(
context.level == 2,
lambda: keys.write(
context.level - 2,
shell_ops.modulus_reduce_key64(
context._raw_context, keys.read(context.level - 1)
),
),
lambda: keys,
parallel_iterations=1,
)

return ShellKey64(_raw_keys_at_level=keys.gather(tf.range(0, num_keys)))
Expand Down Expand Up @@ -102,41 +94,30 @@ def create_rotation_key64(context, key):
num_keys = context.level
rot_keys = tf.TensorArray(
tf.variant,
size=context.level,
size=num_keys,
clear_after_read=False,
infer_shape=False,
element_shape=(),
)

# Generate rotation keys starting from the highest level.
rot_keys, context = tf.while_loop(
lambda ks, c: c.level > 1,
lambda ks, c: (
rot_keys, _ = tf.while_loop(
lambda ks, l: l > 0,
lambda ks, l: (
ks.write(
c.level - 1,
l - 1,
shell_ops.rotation_key_gen64(
c._raw_context, key._get_key_at_level(c.level)
context._get_context_at_level(l), key._get_key_at_level(l)
),
),
mod_reduce_context64(c),
l - 1,
),
loop_vars=[rot_keys, context],
loop_vars=[rot_keys, num_keys],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
context._get_generic_context_spec(),
tf.TensorSpec([], dtype=tf.int32),
],
)

# Store the first key for level 1.
rot_keys = tf.cond(
context.level == 1,
lambda: rot_keys.write(
context.level - 1,
shell_ops.rotation_key_gen64(
context._raw_context, key._get_key_at_level(context.level)
),
),
lambda: rot_keys,
parallel_iterations=1,
)

return ShellRotationKey64(_raw_keys_at_level=rot_keys.gather(tf.range(0, num_keys)))
Expand Down Expand Up @@ -171,41 +152,30 @@ def create_fast_rotation_key64(context, key):
num_keys = context.level
rot_keys = tf.TensorArray(
tf.variant,
size=context.level,
size=num_keys,
clear_after_read=False,
infer_shape=False,
element_shape=(),
)

# Generate rotation keys starting from the highest level.
rot_keys, context = tf.while_loop(
lambda ks, c: c.level > 1,
lambda ks, c: (
rot_keys, _ = tf.while_loop(
lambda ks, l: l > 0,
lambda ks, l: (
ks.write(
c.level - 1,
l - 1,
shell_ops.fast_rotation_key_gen64(
c._raw_context, key._get_key_at_level(c.level)
context._get_context_at_level(l), key._get_key_at_level(l)
),
),
mod_reduce_context64(c),
l - 1,
),
loop_vars=[rot_keys, context],
loop_vars=[rot_keys, num_keys],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
context._get_generic_context_spec(),
tf.TensorSpec([], dtype=tf.int32),
],
)

# Store the first key for level 1.
rot_keys = tf.cond(
context.level == 1,
lambda: rot_keys.write(
context.level - 1,
shell_ops.fast_rotation_key_gen64(
context._raw_context, key._get_key_at_level(context.level)
),
),
lambda: rot_keys,
parallel_iterations=1,
)

return ShellFastRotationKey64(
Expand Down
Loading

0 comments on commit 0cf8ad0

Please sign in to comment.