Skip to content

Commit

Permalink
Fix issues with memory semantics for distributed execution.
Browse files Browse the repository at this point in the history
SHELL objects like RnsBgvCiphertext hold raw pointers to moduli. These
are all derived from a leader unique_ptr stored in an RnsContext object.
This causes problems when a ciphertext is sent to another machine. This
commit encodes SHELL's memory semantics with TensorFlow's execution
manager.
  • Loading branch information
james-choncholas committed Sep 26, 2024
1 parent a8bdb36 commit f39e052
Show file tree
Hide file tree
Showing 23 changed files with 594 additions and 286 deletions.
55 changes: 22 additions & 33 deletions examples/label_dp_sgd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-13 00:05:53.757039: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2024-09-13 00:05:53.780113: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"2024-09-23 06:17:56.606856: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2024-09-23 06:17:56.633047: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
Expand Down Expand Up @@ -90,27 +90,7 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"tf_shell_sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" shell_dense (ShellDense) (4096, 64) 50176 \n",
" \n",
" shell_dense_1 (ShellDense) (4096, 10) 640 \n",
" \n",
"=================================================================\n",
"Total params: 50816 (198.50 KB)\n",
"Trainable params: 50816 (198.50 KB)\n",
"Non-trainable params: 0 (0.00 Byte)\n",
"_________________________________________________________________\n"
]
}
],
"outputs": [],
"source": [
"# Turn on the shell optimizer to use autocontext.\n",
"shell_optimizers.enable_tf_shell_optimizer()\n",
Expand All @@ -135,8 +115,6 @@
" scaling_factor=3,\n",
" noise_offset_log2=68,\n",
" ),\n",
" None,\n",
" None,\n",
" True,\n",
")\n",
"\n",
Expand All @@ -145,13 +123,11 @@
" optimizer=tf.keras.optimizers.Adam(0.1),\n",
" loss=tf.keras.losses.CategoricalCrossentropy(),\n",
" metrics=[tf.keras.metrics.CategoricalAccuracy()],\n",
" # metrics=[\"accuracy\"],\n",
" # metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
")\n",
"\n",
"m.build([batch_size, 784])\n",
"# m.build([batch_size, 784]) # do not build if using autoparams\n",
"# m(train_dataset)\n",
"m.summary()\n"
"# m.summary()\n"
]
},
{
Expand All @@ -163,9 +139,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-13 00:05:55.276666: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n",
"2024-09-13 00:05:55.276687: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n",
"2024-09-13 00:05:55.276871: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n"
"2024-09-23 06:17:58.163788: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n",
"2024-09-23 06:17:58.163815: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n",
"2024-09-23 06:17:58.163882: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n"
]
},
{
Expand All @@ -176,7 +152,11 @@
"log_n: 12\n",
"t: 65537\n",
"qs: 288230376151760897 288230376152137729 \n",
"14/14 [==============================] - 111s 8s/step - categorical_accuracy: 0.0000e+00 - val_categorical_accuracy: 0.6646\n"
"Final parameters:\n",
"log_n: 12\n",
"t: 65537\n",
"qs: 288230376151760897 288230376152137729 \n",
"15/15 [==============================] - 109s 7s/step - num_slots: 4096.0000 - val_categorical_accuracy: 0.0973\n"
]
}
],
Expand All @@ -192,6 +172,15 @@
"\n",
"history = m.fit(train_dataset, epochs=1, validation_data=val_dataset, callbacks = [tboard_callback])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m.summary()"
]
}
],
"metadata": {
Expand Down
2 changes: 0 additions & 2 deletions tf_shell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@

from tf_shell.python.shell_key import ShellKey64
from tf_shell.python.shell_key import create_key64
from tf_shell.python.shell_key import mod_reduce_key64

from tf_shell.python.shell_key import ShellRotationKey64
from tf_shell.python.shell_key import create_rotation_key64
from tf_shell.python.shell_key import ShellFastRotationKey64
Expand Down
22 changes: 15 additions & 7 deletions tf_shell/cc/kernels/add_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,11 @@ class AddCtCtOp : public OpKernel {
ShellAddSub add_or_sub;
OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, ct_b));

SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// SHELL's addition preserves moduli pointers of the first input.
// Ensure the output holds smart pointers to the input's context to
// prevent premature deletion of the moduli.
SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context,
ct_a_var->error_params);
flat_output(i) = std::move(ct_c_var);
}
}
Expand Down Expand Up @@ -210,8 +213,12 @@ class AddCtPtOp : public OpKernel {
ShellAddSub add_or_sub;
OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, pt_b));

SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// The output ct will hold raw pointers to moduli stored in the input's
// context. Ensure the output ciphertext Variant wrapper holds smart
// pointers to the input's context to prevent premature deletion of the
// moduli.
SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context,
ct_a_var->error_params);
flat_output(i) = std::move(ct_c_var);
}
}
Expand Down Expand Up @@ -324,9 +331,10 @@ class NegCtOp : public OpKernel {

OP_REQUIRES_VALUE(auto ct_out, op_ctx, ct_a.Negate());

SymmetricCtVariant ct_out_var(std::move(ct_out),
shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// The output ct will hold smart pointers to the input's context
// to prevent premature deletion of the moduli.
SymmetricCtVariant ct_out_var(std::move(ct_out), ct_a_var->ct_context,
ct_a_var->error_params);
flat_output(i) = std::move(ct_out_var);
}
}
Expand Down
20 changes: 18 additions & 2 deletions tf_shell/cc/kernels/context_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,35 @@ class ContextImportOp : public OpKernel {
OP_REQUIRES_VALUE(tstring t_seed, op_ctx, GetScalar<tstring>(op_ctx, 5));
std::string seed(t_seed.c_str());

// Allocate the output.
// Allocate the outputs.
Tensor* out0;
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, TensorShape{}, &out0));
Tensor* out1;
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(1, TensorShape{}, &out1));
Tensor* out2;
OP_REQUIRES_OK(op_ctx,
op_ctx->allocate_output(2, TensorShape{qs.size()}, &out2));
Tensor* out3;
OP_REQUIRES_OK(op_ctx,
op_ctx->allocate_output(3, TensorShape{ps.size()}, &out3));
Tensor* out4;
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(4, TensorShape{}, &out4));

// Initialize the context variant and store it in the output.
ContextVariant<T> ctx_variant{};
OP_REQUIRES_OK(op_ctx, ctx_variant.Initialize(log_n, qs, ps, pt_modulus,
noise_variance, seed));

out0->scalar<Variant>()() = std::move(ctx_variant);

// Output other parameters for usage with auto-context.
out1->scalar<uint64_t>()() = log_n;
for (size_t i = 0; i < qs.size(); ++i) {
out2->flat<T>()(i) = qs[i];
}
for (size_t i = 0; i < ps.size(); ++i) {
out3->flat<T>()(i) = ps[i];
}
out4->scalar<T>()() = pt_modulus;
}
};

Expand Down
20 changes: 13 additions & 7 deletions tf_shell/cc/kernels/mod_switch_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ class ModulusReduceKeyOp : public OpKernel {
OP_REQUIRES_VALUE(SymmetricKeyVariant<T> const* secret_key_var, op_ctx,
GetVariant<SymmetricKeyVariant<T>>(op_ctx, 1));
OP_REQUIRES(
op_ctx, secret_key_var->key != nullptr,
op_ctx, secret_key_var != nullptr,
InvalidArgument("SymmetricKeyVariant did not unwrap successfully."));
OP_REQUIRES_OK(op_ctx,
const_cast<SymmetricKeyVariant<T>*>(secret_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_,
shell_ctx_var->noise_variance_));
OP_REQUIRES(op_ctx, secret_key_var->key != nullptr,
InvalidArgument(
"SymmetricKeyVariant key did not unwrap successfully."));
Key secret_key = *secret_key_var->key; // Deep copy.

// Allocate a scalar output tensor to store the reduced key.
Expand All @@ -102,9 +105,11 @@ class ModulusReduceKeyOp : public OpKernel {

OP_REQUIRES_OK(op_ctx, secret_key.ModReduce());

// Store the reduced key in the output tensor.
// Store the reduced key in the output tensor. Keep a reference to the
// original context (even though it has the un-reduced moduli) to ensure
// the moduli held internally by the key are not deleted prematurely.
SymmetricKeyVariant<T> reduced_key_variant(std::move(secret_key),
shell_ctx_var->ct_context_);
secret_key_var->ct_context);
out->scalar<Variant>()() = std::move(reduced_key_variant);
}
};
Expand Down Expand Up @@ -169,10 +174,11 @@ class ModulusReduceCtOp : public OpKernel {

OP_REQUIRES_OK(op_ctx, result_ct.ModReduce(t, ql_inv));

// Store in the output.
SymmetricCtVariant<T> result_var(std::move(result_ct),
shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// Store in the output. Keep a reference to the original context to
// ensure the moduli held internally by the ciphertext are not deleted
// prematurely.
SymmetricCtVariant<T> result_var(
std::move(result_ct), ct_a_var->ct_context, ct_a_var->error_params);
flat_output(i) = std::move(result_var);
}
};
Expand Down
66 changes: 46 additions & 20 deletions tf_shell/cc/kernels/mul_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ class MulCtCtOp : public OpKernel {

OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, ct_a * ct_b);

SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// Wrap the result in a SymmetricCtVariant and store it in the output.
// SHELL's multiplication preserves moduli pointers of the first input.
// Ensure the output holds smart pointers to the input's context to
// prevent premature deletion of the moduli.
SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context,
ct_a_var->error_params);
flat_output(i) = std::move(ct_c_var);
}
}
Expand Down Expand Up @@ -182,8 +186,13 @@ class MulCtPtOp : public OpKernel {
OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx,
ct_a * pt_b); // shell absorb operation

SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
// Wrap the result in a SymmetricCtVariant and store it in the output.
// The output ct will hold raw pointers to moduli stored in the input's
// context. Ensure the output ciphertext Variant wrapper holds smart
// pointers to the input's context to prevent premature deletion of the
// moduli.
SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context,
ct_a_var->error_params);
flat_output(i) = std::move(ct_c_var);
}
};
Expand Down Expand Up @@ -261,8 +270,8 @@ class MulShellTfScalarOp : public OpKernel {
OP_REQUIRES_VALUE(RnsPolynomial result, op_ctx,
poly.Mul(wrapped_b, shell_ctx->MainPrimeModuli()));

CtOrPolyVariant result_var(std::move(result),
shell_ctx_var->ct_context_);
PolynomialVariant<T> result_var(std::move(result),
shell_ctx_var->ct_context_);
flat_output(i) = std::move(result_var);
} else if constexpr (std::is_same<CtOrPolyVariant,
SymmetricCtVariant<T>>::value) {
Expand All @@ -275,9 +284,13 @@ class MulShellTfScalarOp : public OpKernel {
OP_REQUIRES_VALUE(SymmetricCt result, op_ctx,
ct * wrapped_b); // shell absorb operation

// The output ct will hold raw pointers to moduli stored in the input's
// context. Ensure the output ciphertext Variant wrapper holds smart
// pointers to the input's context to prevent premature deletion of the
// moduli.
SymmetricCtVariant result_var(std::move(result),
shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
ct_or_pt_var->ct_context,
ct_or_pt_var->error_params);
flat_output(i) = std::move(result_var);
}
}
Expand Down Expand Up @@ -492,9 +505,8 @@ class MatMulCtPtOp : public OpKernel {
OP_REQUIRES_OK(op_ctx, ct_result.AddInPlace(scaled));
}

SymmetricCtVariant ct_result_var(std::move(ct_result),
shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
SymmetricCtVariant ct_result_var(
std::move(ct_result), ct_a_var->ct_context, ct_a_var->error_params);
flat_output(i) = std::move(ct_result_var);
}
};
Expand Down Expand Up @@ -554,10 +566,6 @@ class MatMulPtCtOp : public OpKernel {
auto const& sub_powers = shell_ctx_var->substitution_powers_;
OP_REQUIRES(op_ctx, shell_ctx != nullptr,
InvalidArgument("Shell context object is empty."));
auto const& main_moduli = shell_ctx->MainPrimeModuli();
std::vector<Modulus const*> main_moduli_vector;
main_moduli_vector.assign(main_moduli.begin(), main_moduli.end());

Encoder const* encoder = shell_ctx_var->encoder_.get();
// TODO if debug
OP_REQUIRES(op_ctx, encoder != nullptr,
Expand Down Expand Up @@ -623,6 +631,22 @@ class MatMulPtCtOp : public OpKernel {
auto flat_output = output->shaped<Variant, 3>(
{num_pt_outer_dims, num_pt_inner_rows, num_ct_cols});

// Extract the first ciphertext in b to recover the moduli. These are used
// to create new ciphertexts after fast rotations. Note that the moduli must
// come from the ciphertext, not the shell context, to ensure smart pointers
// are properly preserved.
SymmetricCtVariant<T> const* first_ct_b_var =
std::move(flat_b(0).get<SymmetricCtVariant<T>>());
OP_REQUIRES(op_ctx, first_ct_b_var != nullptr,
InvalidArgument("SymmetricCtVariant at flat index: 0",
" for input b did not unwrap successfully."));
OP_REQUIRES_OK(op_ctx, const_cast<SymmetricCtVariant<T>*>(first_ct_b_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_,
shell_ctx_var->error_params_));
auto const& main_moduli = first_ct_b_var->ct.Moduli();
std::vector<Modulus const*> main_moduli_vector;
main_moduli_vector.assign(main_moduli.begin(), main_moduli.end());

// Setup constants used in parallelizing the computation.
auto thread_pool =
op_ctx->device()->tensorflow_cpu_worker_threads()->workers;
Expand Down Expand Up @@ -721,9 +745,11 @@ class MatMulPtCtOp : public OpKernel {
};

// Recreate the ciphertext with the new components.
ct_result = SymmetricCt{
std::move(components), main_moduli_vector, ct_b.PowerOfS(),
ct_b.Error() * ct_b.LogN(), ct_b.ErrorParams()};
// TODO(james-choncholas): Noise estimation is not correct.
ct_result = SymmetricCt{std::move(components), main_moduli_vector,
ct_result.PowerOfS(),
ct_result.Error() * ct_result.LogN(),
ct_result.ErrorParams()};

} else {
for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
Expand All @@ -749,8 +775,8 @@ class MatMulPtCtOp : public OpKernel {
// the result of the reduce sum operation. Store in the output
// tensor.
SymmetricCtVariant<T> ct_result_var(std::move(ct_result),
shell_ctx_var->ct_context_,
shell_ctx_var->error_params_);
ct_b_var->ct_context,
ct_b_var->error_params);
flat_output(outer, i, ct_col) = std::move(ct_result_var);
}
}
Expand Down
Loading

0 comments on commit f39e052

Please sign in to comment.