From 0c0ecac60d044eeaddd10a8d268b2d9621084821 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Sun, 10 Nov 2024 01:50:32 +0000 Subject: [PATCH] Parallelize remaining single threaded add and mul kernels. --- examples/benchmark.ipynb | 37 +-- tf_shell/cc/kernels/add_kernels.cc | 387 +++++++++++++++++++---------- tf_shell/cc/kernels/mul_kernels.cc | 286 +++++++++++++-------- 3 files changed, 459 insertions(+), 251 deletions(-) diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index 8671910..8c98e59 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -16,9 +16,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-10-29 21:41:36.488386: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-10-29 21:41:36.514318: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + "2024-11-10 01:15:42.366708: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-11-10 01:15:42.367179: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-11-10 01:15:42.369814: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-11-10 01:15:42.376276: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1731201342.387674 117259 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1731201342.391135 117259 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-11-10 01:15:42.402569: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-11-10 01:15:43.514528: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)\n" ] }, { @@ -64,7 +71,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.5060953950014664\n" + "0.5067476989970601\n" ] } ], @@ -85,7 +92,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.17610475800029235\n" + "0.1824721849989146\n" ] } ], @@ -106,7 +113,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.5579067959988606\n" + "0.5525883640002576\n" ] } ], @@ -127,7 +134,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.7779848270001821\n" + "0.18785935000050813\n" ] } ], @@ -148,7 +155,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.44140414300272823\n" + "0.19354859399754787\n" ] } ], @@ -169,13 +176,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.415064643999358\n" + "0.17636438800036558\n" ] } ], "source": [ "def ct_ct_mul():\n", - " return enc_a * enc_a\n", + " return enc_a * 4\n", "\n", "time = min(timeit.Timer(ct_ct_mul).repeat(repeat=3, number=1))\n", "print(time)" @@ -190,7 +197,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.8980931189980765\n" + "0.7591958359989803\n" ] } ], @@ -211,7 +218,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.7085658289979619\n" + "0.7197426070015354\n" ] } ], @@ -232,7 +239,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "27.201639023998723\n" + "27.007230432998767\n" ] } ], @@ -253,7 +260,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "360.84974249600054\n" + "370.09557524000047\n" ] } ], @@ -274,7 +281,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "6.758062808999966\n" + "5.250123850997625\n" ] } ], diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index 405d3e6..bf85b8c 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -112,45 +112,68 @@ class AddCtCtOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first ciphertext. + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, const_cast*>(ct_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); + // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - SymmetricCtVariant const* ct_a_var = - std::move(flat_a(a_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, ct_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index: ", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_a = ct_a_var->ct; - - SymmetricCtVariant const* ct_b_var = - std::move(flat_b(b_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, ct_b_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index: ", i, - " for input b did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_b_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_b = ct_b_var->ct; - - ShellAddSub add_or_sub; - OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, ct_b)); - - // SHELL's addition preserves moduli pointers of the first input. - // Ensure the output holds smart pointers to the input's context to - // prevent premature deletion of the moduli. - SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, - ct_a_var->error_params); - flat_output(i) = std::move(ct_c_var); - } + auto add_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + SymmetricCtVariant const* ct_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index: ", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_a = ct_a_var->ct; + + SymmetricCtVariant const* ct_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_b_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index: ", i, + " for input b did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_b_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_b = ct_b_var->ct; + + ShellAddSub add_or_sub; + OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, ct_b)); + + // SHELL's addition preserves moduli pointers of the first input. + // Ensure the output holds smart pointers to the input's context to + // prevent premature deletion of the moduli. + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); + flat_output(i) = std::move(ct_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_add = 30 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_add, + add_in_range); } }; @@ -182,45 +205,69 @@ class AddCtPtOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first ciphertext. + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, const_cast*>(ct_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); + // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - SymmetricCtVariant const* ct_a_var = - std::move(flat_a(a_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, ct_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index: ", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_a = ct_a_var->ct; - - PolynomialVariant const* pv_b_var = - std::move(flat_b(b_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, pv_b_var != nullptr, - InvalidArgument("PolynomialVariant at flat index: ", i, - " for input b did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pv_b_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_b = pv_b_var->poly; - - ShellAddSub add_or_sub; - OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, pt_b)); - - // The output ct will hold raw pointers to moduli stored in the input's - // context. Ensure the output ciphertext Variant wrapper holds smart - // pointers to the input's context to prevent premature deletion of the - // moduli - SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, - ct_a_var->error_params); - flat_output(i) = std::move(ct_c_var); - } + auto add_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + SymmetricCtVariant const* ct_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index: ", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_a = ct_a_var->ct; + + PolynomialVariant const* pv_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_b_var != nullptr, + InvalidArgument("PolynomialVariant at flat index: ", i, + " for input b did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pv_b_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_b = pv_b_var->poly; + + ShellAddSub add_or_sub; + OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, pt_b)); + + // The output ct will hold raw pointers to moduli stored in the input's + // context. Ensure the output ciphertext Variant wrapper holds smart + // pointers to the input's context to prevent premature deletion of the + // moduli + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); + flat_output(i) = std::move(ct_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_add = 30 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_add, + add_in_range); } }; @@ -253,41 +300,65 @@ class AddPtPtOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first plaintext. + PolynomialVariant const* pt_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, pt_var != nullptr, + InvalidArgument("PolynomialVariant a did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(pt_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt = pt_var->poly; + int num_slots = 1 << pt.LogN(); + // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - PolynomialVariant const* pv_a_var = - std::move(flat_a(a_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, pv_a_var != nullptr, - InvalidArgument("PolynomialVariant at flat index: ", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pv_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_a = pv_a_var->poly; - - PolynomialVariant const* pv_b_var = - std::move(flat_b(b_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, pv_b_var != nullptr, - InvalidArgument("PolynomialVariant at flat index: ", i, - " for input b did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pv_b_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_b = pv_b_var->poly; - - ShellAddSubWithParams add_or_sub; - OP_REQUIRES_VALUE(RnsPolynomial pt_c, op_ctx, - add_or_sub(pt_a, pt_b, shell_ctx->MainPrimeModuli())); - - PolynomialVariant pt_c_var(std::move(pt_c), - shell_ctx_var->ct_context_); - flat_output(i) = std::move(pt_c_var); - } + auto add_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + PolynomialVariant const* pv_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_a_var != nullptr, + InvalidArgument("PolynomialVariant at flat index: ", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pv_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_a = pv_a_var->poly; + + PolynomialVariant const* pv_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_b_var != nullptr, + InvalidArgument("PolynomialVariant at flat index: ", i, + " for input b did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pv_b_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_b = pv_b_var->poly; + + ShellAddSubWithParams add_or_sub; + OP_REQUIRES_VALUE(RnsPolynomial pt_c, op_ctx, + add_or_sub(pt_a, pt_b, shell_ctx->MainPrimeModuli())); + + PolynomialVariant pt_c_var(std::move(pt_c), + shell_ctx_var->ct_context_); + flat_output(i) = std::move(pt_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_add = 20 * num_slots; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_add, + add_in_range); } }; @@ -308,35 +379,57 @@ class NegCtOp : public OpKernel { OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); Tensor const& a = op_ctx->input(1); + auto flat_a = a.flat(); + + // Recover num_slots from first ciphertext. + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, const_cast*>(ct_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); // Allocate the output tensor which is the same size as the input. Tensor* output; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - // Set up flat views of the input and output tensors. - auto flat_a = a.flat(); + // Set up flat view of the output tensor. auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - SymmetricCtVariant const* ct_a_var = - std::move(flat_a(i).get>()); - OP_REQUIRES(op_ctx, ct_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index: ", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_a = ct_a_var->ct; - - OP_REQUIRES_VALUE(auto ct_out, op_ctx, ct_a.Negate()); - - // The output ct will hold smart pointers to the input's context - // to prevent premature deletion of the moduli. - SymmetricCtVariant ct_out_var(std::move(ct_out), ct_a_var->ct_context, - ct_a_var->error_params); - flat_output(i) = std::move(ct_out_var); - } + auto negate_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + SymmetricCtVariant const* ct_a_var = + std::move(flat_a(i).get>()); + OP_REQUIRES( + op_ctx, ct_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index: ", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_a = ct_a_var->ct; + + OP_REQUIRES_VALUE(auto ct_out, op_ctx, ct_a.Negate()); + + // The output ct will hold smart pointers to the input's context + // to prevent premature deletion of the moduli. + SymmetricCtVariant ct_out_var(std::move(ct_out), ct_a_var->ct_context, + ct_a_var->error_params); + flat_output(i) = std::move(ct_out_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_neg = 20 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_neg, + negate_in_range); } }; @@ -357,32 +450,54 @@ class NegPtOp : public OpKernel { Context const* shell_ctx = shell_ctx_var->ct_context_.get(); Tensor const& a = op_ctx->input(1); + auto flat_a = a.flat(); + + // Recover num_slots from first plaintext. + PolynomialVariant const* pt_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, pt_var != nullptr, + InvalidArgument("PolynomialVariant a did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(pt_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt = pt_var->poly; + int num_slots = 1 << pt.LogN(); // Allocate the output tensor which is the same size as the input. Tensor* output; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - // Set up flat views of the input and output tensors. - auto flat_a = a.flat(); + // Set up flat view of the output tensor. auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - PolynomialVariant const* pt_a_var = - std::move(flat_a(i).get>()); - OP_REQUIRES(op_ctx, pt_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index: ", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pt_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_a = pt_a_var->poly; - - OP_REQUIRES_VALUE(RnsPolynomial pt_out, op_ctx, - pt_a.Negate(shell_ctx->MainPrimeModuli())); - PolynomialVariant pt_out_var(std::move(pt_out), - shell_ctx_var->ct_context_); - flat_output(i) = std::move(pt_out_var); - } + auto negate_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + PolynomialVariant const* pt_a_var = + std::move(flat_a(i).get>()); + OP_REQUIRES( + op_ctx, pt_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index: ", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pt_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_a = pt_a_var->poly; + + OP_REQUIRES_VALUE(RnsPolynomial pt_out, op_ctx, + pt_a.Negate(shell_ctx->MainPrimeModuli())); + PolynomialVariant pt_out_var(std::move(pt_out), + shell_ctx_var->ct_context_); + flat_output(i) = std::move(pt_out_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_neg = 20 * num_slots; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_neg, + negate_in_range); } }; diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index a3f44d4..eb7f9ac 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -70,6 +70,19 @@ class MulCtCtOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first ciphertext. + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, const_cast*>(ct_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); + // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); @@ -77,39 +90,49 @@ class MulCtCtOp : public OpKernel { auto flat_output = output->flat(); // Multiply each pair of ciphertexts and store the result in the output. - for (int i = 0; i < flat_output.dimension(0); ++i) { - SymmetricCtVariant const* ct_a_var = - std::move(flat_a(a_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, ct_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index:", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_a = ct_a_var->ct; + auto mul_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + SymmetricCtVariant const* ct_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index:", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_a = ct_a_var->ct; - SymmetricCtVariant const* ct_b_var = - std::move(flat_b(b_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, ct_b_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index:", i, - " for input b did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, - const_cast*>(ct_b_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); - SymmetricCt const& ct_b = ct_b_var->ct; + SymmetricCtVariant const* ct_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_b_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index:", i, + " for input b did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_b_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct_b = ct_b_var->ct; - OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, ct_a * ct_b); + OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, ct_a * ct_b); - // Wrap the result in a SymmetricCtVariant and store it in the output. - // SHELL's multiplication preserves moduli pointers of the first input. - // Ensure the output holds smart pointers to the input's context to - // prevent premature deletion of the moduli. - SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, - ct_a_var->error_params); - flat_output(i) = std::move(ct_c_var); - } + // Wrap the result in a SymmetricCtVariant and store it in the output. + // SHELL's multiplication preserves moduli pointers of the first input. + // Ensure the output holds smart pointers to the input's context to + // prevent premature deletion of the moduli. + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); + flat_output(i) = std::move(ct_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_mul = 30 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_mul, + mul_in_range); } }; @@ -244,6 +267,36 @@ class MulShellTfScalarOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first or plaintext. + int num_slots; + int num_components; + if constexpr (std::is_same>::value) { + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(ct_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_, shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_var->ct; + num_slots = 1 << ct.LogN(); + num_components = ct.NumModuli(); + } else { + PolynomialVariant const* pt_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, pt_var != nullptr, + InvalidArgument("PolynomialVariant a did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(pt_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt = pt_var->poly; + num_slots = 1 << pt.LogN(); + num_components = 1; + } + // Allocate the output tensor which is the same shape as the first input. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); @@ -251,52 +304,61 @@ class MulShellTfScalarOp : public OpKernel { auto flat_output = output->flat(); // Now multiply. - for (int i = 0; i < flat_output.dimension(0); ++i) { - // First encode the scalar b - // TDOO(jchoncholas): encode all scalars at once beforehand. - T wrapped_b{}; - EncodeScalar(op_ctx, flat_b(b_bcaster(i)), encoder, &wrapped_b); - - CtOrPolyVariant const* ct_or_pt_var = - std::move(flat_a(a_bcaster(i)).get()); - OP_REQUIRES(op_ctx, ct_or_pt_var != nullptr, - InvalidArgument("Input at flat index:", i, - " for input a did not unwrap successfully.")); - - if constexpr (std::is_same>::value) { - OP_REQUIRES_OK(op_ctx, - const_cast*>(ct_or_pt_var) - ->MaybeLazyDecode(shell_ctx_var->ct_context_)); - RnsPolynomial const& poly = ct_or_pt_var->poly; - - OP_REQUIRES_VALUE(RnsPolynomial result, op_ctx, - poly.Mul(wrapped_b, shell_ctx->MainPrimeModuli())); - - PolynomialVariant result_var(std::move(result), - shell_ctx_var->ct_context_); - flat_output(i) = std::move(result_var); - } else if constexpr (std::is_same>::value) { - OP_REQUIRES_OK(op_ctx, - const_cast*>(ct_or_pt_var) - ->MaybeLazyDecode(shell_ctx_var->ct_context_, - shell_ctx_var->error_params_)); - SymmetricCt const& ct = ct_or_pt_var->ct; - - OP_REQUIRES_VALUE(SymmetricCt result, op_ctx, - ct * wrapped_b); // shell aborb operation + auto mul_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + // First encode the scalar b + // TDOO(jchoncholas): encode all scalars at once beforehand. + T wrapped_b{}; + EncodeScalar(op_ctx, flat_b(b_bcaster(i)), encoder, &wrapped_b); - // The output ct will hold raw pointers to moduli stored in the input's - // context. Ensure the output ciphertext Variant wrapper holds smart - // pointers to the input's context to prevent premature deletion of the - // moduli - SymmetricCtVariant result_var(std::move(result), - ct_or_pt_var->ct_context, - ct_or_pt_var->error_params); - flat_output(i) = std::move(result_var); + CtOrPolyVariant const* ct_or_pt_var = + std::move(flat_a(a_bcaster(i)).get()); + OP_REQUIRES( + op_ctx, ct_or_pt_var != nullptr, + InvalidArgument("Input at flat index:", i, + " for input a did not unwrap successfully.")); + + if constexpr (std::is_same>::value) { + OP_REQUIRES_OK(op_ctx, + const_cast*>(ct_or_pt_var) + ->MaybeLazyDecode(shell_ctx_var->ct_context_)); + RnsPolynomial const& poly = ct_or_pt_var->poly; + + OP_REQUIRES_VALUE(RnsPolynomial result, op_ctx, + poly.Mul(wrapped_b, shell_ctx->MainPrimeModuli())); + + PolynomialVariant result_var(std::move(result), + shell_ctx_var->ct_context_); + flat_output(i) = std::move(result_var); + } else if constexpr (std::is_same>::value) { + OP_REQUIRES_OK(op_ctx, + const_cast*>(ct_or_pt_var) + ->MaybeLazyDecode(shell_ctx_var->ct_context_, + shell_ctx_var->error_params_)); + SymmetricCt const& ct = ct_or_pt_var->ct; + + OP_REQUIRES_VALUE(SymmetricCt result, op_ctx, + ct * wrapped_b); // shell aborb operation + + // The output ct will hold raw pointers to moduli stored in the + // input's context. Ensure the output ciphertext Variant wrapper holds + // smart pointers to the input's context to prevent premature deletion + // of the moduli + SymmetricCtVariant result_var(std::move(result), + ct_or_pt_var->ct_context, + ct_or_pt_var->error_params); + flat_output(i) = std::move(result_var); + } } - } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_mul = 20 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_mul, + mul_in_range); } private: @@ -357,6 +419,18 @@ class MulPtPtOp : public OpKernel { IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); + // Recover num_slots from first plaintext. + PolynomialVariant const* pt_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, pt_var != nullptr, + InvalidArgument("PolynomialVariant a did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(pt_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt = pt_var->poly; + int num_slots = 1 << pt.LogN(); + // Allocate the output tensor which is the same shape as each of the // inputs. Tensor* output; @@ -364,34 +438,46 @@ class MulPtPtOp : public OpKernel { OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - PolynomialVariant const* pv_a_var = - std::move(flat_a(a_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, pv_a_var != nullptr, - InvalidArgument("PolynomialVariant at flat index:", i, - " for input a did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pv_a_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_a = pv_a_var->poly; - - PolynomialVariant const* pv_b_var = - std::move(flat_b(b_bcaster(i)).get>()); - OP_REQUIRES(op_ctx, pv_b_var != nullptr, - InvalidArgument("PolynomialVariant at flat index:", i, - " for input b did not unwrap successfully.")); - OP_REQUIRES_OK( - op_ctx, const_cast*>(pv_b_var)->MaybeLazyDecode( - shell_ctx_var->ct_context_)); - RnsPolynomial const& pt_b = pv_b_var->poly; + auto add_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + PolynomialVariant const* pv_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_a_var != nullptr, + InvalidArgument("PolynomialVariant at flat index:", i, + " for input a did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pv_a_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_a = pv_a_var->poly; + + PolynomialVariant const* pv_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_b_var != nullptr, + InvalidArgument("PolynomialVariant at flat index:", i, + " for input b did not unwrap successfully.")); + OP_REQUIRES_OK( + op_ctx, + const_cast*>(pv_b_var)->MaybeLazyDecode( + shell_ctx_var->ct_context_)); + RnsPolynomial const& pt_b = pv_b_var->poly; - OP_REQUIRES_VALUE(RnsPolynomial pt_c, op_ctx, - pt_a.Mul(pt_b, shell_ctx->MainPrimeModuli())); + OP_REQUIRES_VALUE(RnsPolynomial pt_c, op_ctx, + pt_a.Mul(pt_b, shell_ctx->MainPrimeModuli())); - PolynomialVariant pt_c_var(std::move(pt_c), - shell_ctx_var->ct_context_); - flat_output(i) = std::move(pt_c_var); - } + PolynomialVariant pt_c_var(std::move(pt_c), + shell_ctx_var->ct_context_); + flat_output(i) = std::move(pt_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_add = 20 * num_slots; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_add, + add_in_range); } };