Key caching supports distributed execution.
james-choncholas committed Oct 11, 2024
1 parent 5b39de4 commit 6105513
Showing 4 changed files with 322 additions and 229 deletions.
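The user-facing piece of this commit is the new optional cache_path argument on create_autocontext64 (plus the id_str field threaded through ShellContext64), which lets the autotuned context be written to a shared filesystem once and then loaded verbatim by every worker in a distributed job. A minimal usage sketch, assuming the function is exported as tf_shell.create_autocontext64 and using illustrative parameter values:

import tf_shell

# cache_path is the new argument from this commit; the directory is assumed
# to already exist and be readable by all workers. Values are illustrative.
context = tf_shell.create_autocontext64(
    log2_cleartext_sz=32,
    scaling_factor=3,
    noise_offset_log2=68,
    noise_variance=8,
    cache_path="/tmp/tf_shell_cache",
)

Per the warning in the code below, the cache is keyed only on the context parameters, so stale entries must be cleared by hand if the surrounding graph changes.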
47 changes: 25 additions & 22 deletions tf_shell/cc/optimizers/moduli_autotune.cc
@@ -120,12 +120,28 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation,
return status;
}

Status GetAutoShellContextParams(utils::MutableGraphView& graph_view,
ShellAutoParams& params) {
// Get the plaintext modulus t.
auto const* autocontext_node_view = graph_view.GetNode(kShellAutoContext);
// auto const* autocontext_node_def = autocontext_node_view->node();
StatusOr<utils::MutableNodeView*> GetAutoShellContextNode(
utils::MutableGraphView& graph_view) {
utils::MutableNodeView* found = nullptr;
for (int i = 0; i < graph_view.NumNodes(); ++i) {
auto const* node_view = graph_view.GetNode(i);
auto const* node_def = node_view->node();
if (node_def->op() == kShellAutoContext) {
if (found != nullptr) {
return errors::InvalidArgument(
"Multiple AutoShellContext nodes found in the graph.");
}
found = graph_view.GetNode(i);
}
}
if (found == nullptr) {
return errors::NotFound("AutoShellContext node not found.");
}
return found;
}

Status GetAutoShellContextParams(utils::MutableNodeView* autocontext_node_view,
ShellAutoParams& params) {
auto const* cleartext_bits_node =
autocontext_node_view->GetRegularFanin(0).node_view()->node();
TF_RETURN_IF_ERROR(GetScalarConstValue<uint64_t, DT_UINT64>(
@@ -656,8 +672,8 @@ Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
ShellParams const& params,
ShellAutoParams const& auto_params) {
utils::MutableNodeView* autocontext_node_view =
graph_view.GetNode(kShellAutoContext);
TF_ASSIGN_OR_RETURN(auto* autocontext_node_view,
GetAutoShellContextNode(graph_view));
int autocontext_node_index = autocontext_node_view->node_index();

if constexpr (debug_graph) {
@@ -753,25 +769,12 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,

// See if an autocontext node exists in the graph. If not, there is nothing
// to do.
auto const* autocontext_view = graph_view.GetNode(kShellAutoContext);
if (autocontext_view == nullptr) {
*optimized_graph = std::move(mutable_item.graph);
return OkStatus();
}

// Make sure there is only one autocontext node.
std::string duplicate_autocontext = kShellAutoContext;
duplicate_autocontext += "_1";
auto const* duplicate_autocontext_view =
graph_view.GetNode(duplicate_autocontext);
if (duplicate_autocontext_view != nullptr) {
return errors::FailedPrecondition("Multiple autocontext nodes found.");
}
TF_ASSIGN_OR_RETURN(auto* autocontext, GetAutoShellContextNode(graph_view));

// Use GetScalarConstValue to get value of plaintext modulus,
// etc.
ShellAutoParams auto_params;
TF_RETURN_IF_ERROR(GetAutoShellContextParams(graph_view, auto_params));
TF_RETURN_IF_ERROR(GetAutoShellContextParams(autocontext, auto_params));

// Topological sort so all subsequent traversals are in order.
TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
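For orientation, the new GetAutoShellContextNode above finds the node by op type rather than by name, and fails loudly on duplicates; this keeps the lookup working once the Python side (below) wraps context creation in tf.name_scope, which prefixes node names. The same logic over a plain GraphDef, as an illustrative Python sketch rather than the grappler C++ API used in the commit:

import tensorflow as tf

def find_single_op(graph_def, op_name):
    # Scan every node in the graph; error on duplicates, error when absent.
    found = None
    for node in graph_def.node:
        if node.op == op_name:
            if found is not None:
                raise ValueError(f"Multiple {op_name} nodes found in the graph.")
            found = node
    if found is None:
        raise KeyError(f"{op_name} node not found.")
    return found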
247 changes: 168 additions & 79 deletions tf_shell/python/shell_context.py
@@ -31,6 +31,7 @@ class ShellContext64(tf.experimental.ExtensionType):
noise_variance: int
scaling_factor: int
seed: str
id_str: str

def __init__(
self,
@@ -43,6 +44,7 @@ def __init__(
noise_variance,
scaling_factor,
seed,
id_str,
):
self._raw_contexts = _raw_contexts
self.is_autocontext = is_autocontext
@@ -65,15 +67,10 @@ def __init__(
self.noise_variance = noise_variance
self.scaling_factor = scaling_factor
self.seed = seed
self.id_str = id_str

def _get_context_at_level(self, level):
level -= 1 # Context tensor start at level 1.
tf.Assert(level >= 0, [f"level must be >= 0. Got {level}"])
tf.Assert(
level < tf.shape(self._raw_contexts)[0],
[f"level must be < {tf.shape(self._raw_contexts)[0]}. Got {level}"],
)
return self._raw_contexts[level]
return self._raw_contexts[level - 1] # 0th level does not exist.

def _get_generic_context_spec(self):
return ShellContext64.Spec(
@@ -106,46 +103,65 @@ def create_context64(
elif len(seed) < 64 and seed != "":
seed = seed.ljust(64)

context_sz = tf.shape(main_moduli)[0]
raw_contexts = tf.TensorArray(tf.variant, size=context_sz, clear_after_read=False)

# Generate and store the first context in the last index.
first_context, _, _, _, _ = shell_ops.context_import64(
log_n=log_n,
main_moduli=main_moduli,
aux_moduli=aux_moduli,
plaintext_modulus=plaintext_modulus,
noise_variance=noise_variance,
seed=seed,
)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
id_str = str(
hash(
(
log_n,
tuple(main_moduli),
plaintext_modulus,
tuple(aux_moduli),
noise_variance,
scaling_factor,
seed,
)
)
)

return ShellContext64(
_raw_contexts=raw_contexts.gather(tf.range(0, context_sz)),
is_autocontext=False,
log_n=log_n,
main_moduli=main_moduli,
aux_moduli=aux_moduli,
plaintext_modulus=plaintext_modulus,
noise_variance=noise_variance,
scaling_factor=scaling_factor,
seed=seed,
)
with tf.name_scope("create_context64"):
context_sz = tf.shape(main_moduli)[0]
raw_contexts = tf.TensorArray(
tf.variant, size=context_sz, clear_after_read=False
)

# Generate and store the first context in the last index.
first_context, _, _, _, _ = shell_ops.context_import64(
log_n=log_n,
main_moduli=main_moduli,
aux_moduli=aux_moduli,
plaintext_modulus=plaintext_modulus,
noise_variance=noise_variance,
seed=seed,
)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
)
raw_contexts = raw_contexts.gather(tf.range(0, context_sz))

return ShellContext64(
_raw_contexts=raw_contexts,
is_autocontext=False,
log_n=log_n,
main_moduli=main_moduli,
aux_moduli=aux_moduli,
plaintext_modulus=plaintext_modulus,
noise_variance=noise_variance,
scaling_factor=scaling_factor,
seed=seed,
id_str=id_str,
)
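One caveat with the id_str scheme above: Python's built-in hash() salts string hashing per interpreter (see PYTHONHASHSEED), and the hashed tuple includes the string seed, so identical parameters can yield different cache keys on different workers or runs unless PYTHONHASHSEED is pinned. A deterministic digest sidesteps this; a hypothetical sketch, not what the commit does:

import hashlib

def stable_id_str(log_n, main_moduli, plaintext_modulus, aux_moduli,
                  noise_variance, scaling_factor, seed):
    # repr() of a tuple of ints and strings is stable across processes,
    # unlike hash(), whose str hashing is randomized per interpreter.
    params = (log_n, tuple(main_moduli), plaintext_modulus,
              tuple(aux_moduli), noise_variance, scaling_factor, seed)
    return hashlib.sha256(repr(params).encode("utf-8")).hexdigest()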


def create_autocontext64(
@@ -154,45 +170,118 @@ def create_autocontext64(
noise_offset_log2,
noise_variance=8,
seed="",
cache_path=None, # WARN: Caching will not update if graph changes.
):
if len(seed) > 64:
raise ValueError("Seed must be at most 64 characters long.")
elif len(seed) < 64 and seed != "":
seed = seed.ljust(64)

first_context, new_log_n, new_qs, new_ps, new_t = shell_ops.auto_shell_context64(
log2_cleartext_sz=log2_cleartext_sz,
scaling_factor=scaling_factor,
log2_noise_offset=noise_offset_log2,
noise_variance=noise_variance,
)
context_sz = tf.shape(new_qs)[0]
raw_contexts = tf.TensorArray(tf.variant, size=context_sz, clear_after_read=False)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
id_str = str(
hash(
(
log2_cleartext_sz,
scaling_factor,
noise_offset_log2,
noise_variance,
seed,
)
)
)

return ShellContext64(
_raw_contexts=raw_contexts.gather(tf.range(0, context_sz)),
is_autocontext=True,
log_n=new_log_n,
main_moduli=new_qs,
aux_moduli=new_ps,
plaintext_modulus=new_t,
noise_variance=noise_variance,
scaling_factor=scaling_factor,
seed=seed,
)
with tf.name_scope("create_autocontext64"):
if cache_path != None:
context_cache_path = cache_path + "/" + id_str + "_context"
log_n_cache_path = cache_path + "/" + id_str + "_log_n"
qs_cache_path = cache_path + "/" + id_str + "_qs"
ps_cache_path = cache_path + "/" + id_str + "_ps"
t_cache_path = cache_path + "/" + id_str + "_t"
paths = [
context_cache_path,
log_n_cache_path,
qs_cache_path,
ps_cache_path,
t_cache_path,
]

exists = all([tf.io.gfile.exists(p) for p in paths])
if exists:

def read_and_parse(path, ttype):
return tf.io.parse_tensor(tf.io.read_file(path), out_type=ttype)

raw_contexts = read_and_parse(context_cache_path, tf.variant)
new_log_n = read_and_parse(log_n_cache_path, tf.uint64)
new_qs = read_and_parse(qs_cache_path, tf.uint64)
new_ps = read_and_parse(ps_cache_path, tf.uint64)
new_t = read_and_parse(t_cache_path, tf.uint64)

# log_n and t will always be scalars. Set the static shape
# manually to help with shape inference.
new_log_n.set_shape([])
new_t.set_shape([])

return ShellContext64(
_raw_contexts=raw_contexts,
is_autocontext=True,
log_n=new_log_n,
main_moduli=new_qs,
aux_moduli=new_ps,
plaintext_modulus=new_t,
noise_variance=noise_variance,
scaling_factor=scaling_factor,
seed=seed,
id_str=id_str,
)

# Cache was not found, generate the context.
first_context, new_log_n, new_qs, new_ps, new_t = (
shell_ops.auto_shell_context64(
log2_cleartext_sz=log2_cleartext_sz,
scaling_factor=scaling_factor,
log2_noise_offset=noise_offset_log2,
noise_variance=noise_variance,
)
)
context_sz = tf.shape(new_qs)[0]
raw_contexts = tf.TensorArray(
tf.variant, size=context_sz, clear_after_read=False
)
raw_contexts = raw_contexts.write(context_sz - 1, first_context)

# Mod reduce to compute the remaining contexts.
raw_contexts, _ = tf.while_loop(
lambda cs, l: l > 0,
lambda cs, l: (
cs.write(l - 1, shell_ops.modulus_reduce_context64(cs.read(l))),
l - 1,
),
loop_vars=[raw_contexts, context_sz - 1],
shape_invariants=[
tf.TensorSpec(None, dtype=tf.variant),
tf.TensorSpec([], dtype=tf.int32),
],
parallel_iterations=1,
)

raw_contexts = raw_contexts.gather(tf.range(0, context_sz))

if cache_path != None:
tf.io.write_file(context_cache_path, tf.io.serialize_tensor(raw_contexts))
tf.io.write_file(log_n_cache_path, tf.io.serialize_tensor(new_log_n))
tf.io.write_file(qs_cache_path, tf.io.serialize_tensor(new_qs))
tf.io.write_file(ps_cache_path, tf.io.serialize_tensor(new_ps))
tf.io.write_file(t_cache_path, tf.io.serialize_tensor(new_t))

return ShellContext64(
_raw_contexts=raw_contexts,
is_autocontext=True,
log_n=new_log_n,
main_moduli=new_qs,
aux_moduli=new_ps,
plaintext_modulus=new_t,
noise_variance=noise_variance,
scaling_factor=scaling_factor,
seed=seed,
id_str=id_str,
)
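The cache format itself is just tf.io.serialize_tensor / tf.io.parse_tensor round-trips through files, with set_shape restoring the static shape information that parse_tensor cannot infer at graph-construction time. A self-contained sketch of the pattern on an ordinary uint64 tensor (illustrative; the commit applies the same pattern to the variant-typed contexts):

import tensorflow as tf

path = "/tmp/tf_shell_cache_demo"
qs = tf.constant([65537, 12289], dtype=tf.uint64)

# Write: encode the tensor as a serialized string, then write it to a file.
tf.io.write_file(path, tf.io.serialize_tensor(qs))

# Read: parse with an explicit dtype. In graph mode the parsed tensor has
# unknown static shape, hence the set_shape calls in the cache-hit path above.
restored = tf.io.parse_tensor(tf.io.read_file(path), out_type=tf.uint64)
restored.set_shape([2])

tf.debugging.assert_equal(qs, restored)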