Skip to content

Commit

Permalink
Fix bug in segment sum op (used by embedding layer) which prevents subsequent operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 30, 2024
1 parent 0f0f7d1 commit 50a46a3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 30 deletions.
58 changes: 38 additions & 20 deletions tf_shell/cc/kernels/segment_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
}
}

// Nothing to reduce. All output values equal to `InitialValueF()`.
if (trivial_reduction) {
return;
}
auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;

// Step 1: Reduce over the ciphertext dimension. There are many slots in a
// ciphertext, and some slots may be assigned to the same output. The
Expand Down Expand Up @@ -269,14 +266,16 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
}
}
};
auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
// Use a fixed block size to avoid stragglers in the reduction.
int64_t const reduction_block_size = 4;
tsl::thread::ThreadPool::SchedulingParams reduction_scheduling_params(
tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize,
std::nullopt, reduction_block_size);
thread_pool->ParallelFor(num_segments, reduction_scheduling_params,
reductionWorker);
// The reduction (step 1) can be skipped if there is nothing to reduce.
if (!trivial_reduction) {
// Use a fixed block size to avoid stragglers in the reduction.
int64_t const reduction_block_size = 4;
tsl::thread::ThreadPool::SchedulingParams reduction_scheduling_params(
tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize,
std::nullopt, reduction_block_size);
thread_pool->ParallelFor(num_segments, reduction_scheduling_params,
reductionWorker);
}

// Step 2: Reduce over the slotting dimension. This requires rotating any
// non-empty slots in the output ciphertexts to the first slot using
Expand Down Expand Up @@ -333,16 +332,35 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
}
}
};

if (reduction_type == "galois") {
// Use a fixed block size to avoid stragglers in the reduction.
int64_t const batchaxis_block_size = 2;
tsl::thread::ThreadPool::SchedulingParams batchaxis_scheduling_params(
tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize,
std::nullopt, batchaxis_block_size);
thread_pool->ParallelFor(num_segments, batchaxis_scheduling_params,
if (!trivial_reduction && reduction_type == "galois") {
// Reduction does not generally have long stragglers, so a variable block
// size is best to minimize overhead of parallelization.
int const cost_per_reduce = 18000 * num_slots; // ns
thread_pool->ParallelFor(num_segments, cost_per_reduce,
batchAxisReductionWorker);
}

// Step 3: For any ciphertexts in the output which are empty, fill in zeros.
// Create a zero ciphertext by subtracting one of the inputs from itself.
SymmetricCt zero_ct = data(0, 0).get<SymmetricCtVariant<T>>()->ct;
OP_REQUIRES_OK(ctx, zero_ct.SubInPlace(zero_ct));
auto fillInZerosWorker = [&](int64_t begin, int64_t end) -> void {
for (int64_t i = 0; i < output.dimension(0); ++i) {
for (int64_t j = begin; j < end; ++j) {
for (int64_t chip = 0; chip < inner_dim; ++chip) {
SymmetricCt& ct =
output(i, j, chip).get<SymmetricCtVariant<T>>()->ct;

// If the ciphertext degree is zero, it has no moduli and is empty.
if (ct.Degree() == -1) {
output(i, j, chip).get<SymmetricCtVariant<T>>()->ct = zero_ct;
}
}
}
}
};
int const cost_per_zero = 900;
thread_pool->ParallelFor(num_segments, cost_per_zero, fillInZerosWorker);
}
};

Expand Down
37 changes: 27 additions & 10 deletions tf_shell_ml/test/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ def test_embedding_forward(self):
"Embedding layer forward with encrypted value should fail."
)

def _test_embedding(self):
def _test_embedding(self, reduction):
input_dim = 100
output_dim = 10
embedding_layer = tf_shell_ml.ShellEmbedding(input_dim, output_dim)
output_dim = 16
embedding_layer = tf_shell_ml.ShellEmbedding(
input_dim,
output_dim,
skip_embeddings_below_index=1,
grad_reduction=reduction,
)

sentence_length = 3
special_index = 2
Expand All @@ -80,14 +85,24 @@ def forward_backward(x):

enc_dw, _ = embedding_layer.backward(enc_dy, rotation_key)
dw = tf_shell.to_tensorflow(enc_dw[0], key)
# dw = embedding_layer.unpack(packed_dw)
dw = tf.reduce_sum(dw, axis=0)
return dw
if reduction == "none":
dw = tf.reduce_sum(dw, axis=0)
else:
dw = embedding_layer.unpack(dw)
return dw, enc_dw[0].shape

dw = forward_backward(x)
dw, shape_inf = forward_backward(x)

# Check the inferred shape of the gradient is correct.
if reduction == "none":
self.assertAllEqual(shape_inf, [context.num_slots, input_dim, output_dim])
else:
self.assertAllEqual(
shape_inf, [context.num_slots, 2, input_dim, output_dim]
)

for i in range(0, input_dim):
# Check dw[ special_index] has counted the number of elements.
# Check dw[special_index] has counted the number of elements.
if i == special_index:
self.assertAllEqual(
dw[special_index, :],
Expand All @@ -101,11 +116,13 @@ def forward_backward(x):

def test_embedding_eager(self):
tf.config.run_functions_eagerly(True)
self._test_embedding()
self._test_embedding("galois")
self._test_embedding("none")

def test_embedding_defer(self):
tf.config.run_functions_eagerly(False)
self._test_embedding()
self._test_embedding("galois")
self._test_embedding("none")


if __name__ == "__main__":
Expand Down

0 comments on commit 50a46a3

Please sign in to comment.