diff --git a/tf_shell/cc/kernels/segment_kernels.cc b/tf_shell/cc/kernels/segment_kernels.cc index 0ede42b..90174bd 100644 --- a/tf_shell/cc/kernels/segment_kernels.cc +++ b/tf_shell/cc/kernels/segment_kernels.cc @@ -175,10 +175,7 @@ struct UnsortedSegmentFunctor { } } - // Nothing to reduce. All output values equal to `InitialValueF()`. - if (trivial_reduction) { - return; - } + auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers; // Step 1: Reduce over the ciphertext dimension. There are many slots in a // ciphertext, and some slots may be assigned to the same output. The @@ -269,14 +266,16 @@ struct UnsortedSegmentFunctor { } } }; - auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers; - // Use a fixed block size to avoid stragglers in the reduction. - int64_t const reduction_block_size = 4; - tsl::thread::ThreadPool::SchedulingParams reduction_scheduling_params( - tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, - std::nullopt, reduction_block_size); - thread_pool->ParallelFor(num_segments, reduction_scheduling_params, - reductionWorker); + // The reduction (step 1) can be skipped if there is nothing to reduce. + if (!trivial_reduction) { + // Use a fixed block size to avoid stragglers in the reduction. + int64_t const reduction_block_size = 4; + tsl::thread::ThreadPool::SchedulingParams reduction_scheduling_params( + tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, + std::nullopt, reduction_block_size); + thread_pool->ParallelFor(num_segments, reduction_scheduling_params, + reductionWorker); + } // Step 2: Reduce over the slotting dimension. This requires rotating any // non-empty slots in the output ciphertexts to the first slot using @@ -333,16 +332,35 @@ struct UnsortedSegmentFunctor { } } }; - - if (reduction_type == "galois") { - // Use a fixed block size to avoid stragglers in the reduction. 
-    int64_t const batchaxis_block_size = 2;
-    tsl::thread::ThreadPool::SchedulingParams batchaxis_scheduling_params(
-        tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize,
-        std::nullopt, batchaxis_block_size);
-    thread_pool->ParallelFor(num_segments, batchaxis_scheduling_params,
+    if (!trivial_reduction && reduction_type == "galois") {
+      // Reduction does not generally have long stragglers, so a variable block
+      // size is best to minimize overhead of parallelization.
+      int const cost_per_reduce = 18000 * num_slots; // ns
+      thread_pool->ParallelFor(num_segments, cost_per_reduce,
                                batchAxisReductionWorker);
     }
+
+    // Step 3: For any ciphertexts in the output which are empty, fill in zeros.
+    // Create a zero ciphertext by subtracting one of the inputs from itself.
+    SymmetricCt zero_ct = data(0, 0).get<SymmetricCtVariant<T>>()->ct;
+    OP_REQUIRES_OK(ctx, zero_ct.SubInPlace(zero_ct));
+    auto fillInZerosWorker = [&](int64_t begin, int64_t end) -> void {
+      for (int64_t i = 0; i < output.dimension(0); ++i) {
+        for (int64_t j = begin; j < end; ++j) {
+          for (int64_t chip = 0; chip < inner_dim; ++chip) {
+            SymmetricCt& ct =
+                output(i, j, chip).get<SymmetricCtVariant<T>>()->ct;
+
+            // If the ciphertext degree is zero, it has no moduli and is empty.
+            if (ct.Degree() == -1) {
+              output(i, j, chip).get<SymmetricCtVariant<T>>()->ct = zero_ct;
+            }
+          }
+        }
+      }
+    };
+    int const cost_per_zero = 900;
+    thread_pool->ParallelFor(num_segments, cost_per_zero, fillInZerosWorker);
   }
 };
diff --git a/tf_shell_ml/test/embedding_test.py b/tf_shell_ml/test/embedding_test.py
index 013dd6d..29f516f 100644
--- a/tf_shell_ml/test/embedding_test.py
+++ b/tf_shell_ml/test/embedding_test.py
@@ -59,10 +59,15 @@ def test_embedding_forward(self):
             "Embedding layer forward with encrypted value should fail."
) - def _test_embedding(self): + def _test_embedding(self, reduction): input_dim = 100 - output_dim = 10 - embedding_layer = tf_shell_ml.ShellEmbedding(input_dim, output_dim) + output_dim = 16 + embedding_layer = tf_shell_ml.ShellEmbedding( + input_dim, + output_dim, + skip_embeddings_below_index=1, + grad_reduction=reduction, + ) sentence_length = 3 special_index = 2 @@ -80,14 +85,24 @@ def forward_backward(x): enc_dw, _ = embedding_layer.backward(enc_dy, rotation_key) dw = tf_shell.to_tensorflow(enc_dw[0], key) - # dw = embedding_layer.unpack(packed_dw) - dw = tf.reduce_sum(dw, axis=0) - return dw + if reduction == "none": + dw = tf.reduce_sum(dw, axis=0) + else: + dw = embedding_layer.unpack(dw) + return dw, enc_dw[0].shape - dw = forward_backward(x) + dw, shape_inf = forward_backward(x) + + # Check the inferred shape of the gradient is correct. + if reduction == "none": + self.assertAllEqual(shape_inf, [context.num_slots, input_dim, output_dim]) + else: + self.assertAllEqual( + shape_inf, [context.num_slots, 2, input_dim, output_dim] + ) for i in range(0, input_dim): - # Check dw[ special_index] has counted the number of elements. + # Check dw[special_index] has counted the number of elements. if i == special_index: self.assertAllEqual( dw[special_index, :], @@ -101,11 +116,13 @@ def forward_backward(x): def test_embedding_eager(self): tf.config.run_functions_eagerly(True) - self._test_embedding() + self._test_embedding("galois") + self._test_embedding("none") def test_embedding_defer(self): tf.config.run_functions_eagerly(False) - self._test_embedding() + self._test_embedding("galois") + self._test_embedding("none") if __name__ == "__main__":