diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index 550cc86..a3f44d4 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -547,7 +547,7 @@ class MatMulCtPtOp : public OpKernel { } }; -template +template class MatMulPtCtOp : public OpKernel { using ModularInt = rlwe::MontgomeryInt; using Context = rlwe::RnsContext; @@ -557,8 +557,20 @@ class MatMulPtCtOp : public OpKernel { using Encoder = rlwe::FiniteFieldEncoder; using RotationKey = rlwe::RnsGaloisKey; + std::string reduction; + char const* galois_reduction = "galois"; + char const* fast_reduction = "fast"; + char const* no_reduction = "none"; + public: - explicit MatMulPtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} + explicit MatMulPtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) { + OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("reduction", &reduction)); + OP_REQUIRES( + op_ctx, + reduction == "galois" || reduction == "fast" || reduction == "none", + InvalidArgument("Invalid reduction attribute: ", reduction, + ". Must be 'galois', 'fast', or 'none'.")); + } void Compute(OpKernelContext* op_ctx) override { // Get the input tensors. @@ -579,9 +591,7 @@ class MatMulPtCtOp : public OpKernel { // Rotation keys are only required if fast rotations are enabled. RotationKeyVariant const* rotation_key_var = nullptr; - if constexpr (use_fast_rotations) { - (void)rotation_key_var; - } else { + if (reduction == galois_reduction) { OP_REQUIRES_VALUE(rotation_key_var, op_ctx, GetVariant>(op_ctx, 3)); OP_REQUIRES( @@ -590,10 +600,12 @@ class MatMulPtCtOp : public OpKernel { OP_REQUIRES_OK(op_ctx, const_cast*>(rotation_key_var) ->MaybeLazyDecode(shell_ctx_var->ct_context_)); + } else { + (void)rotation_key_var; } std::vector> empty_rot_keys{}; std::vector> const& rot_keys = - use_fast_rotations ? empty_rot_keys : rotation_key_var->keys; + reduction == galois_reduction ? rotation_key_var->keys : empty_rot_keys; // b is a vector of Polynomials so first dimension is the number of // slots. @@ -729,7 +741,25 @@ class MatMulPtCtOp : public OpKernel { // Note the ciphertext rotations operate on each half of the // ciphertext separately. So the max rotatation is by half the // number of slots. - if constexpr (use_fast_rotations) { + if (reduction == galois_reduction) { + for (int shift = 1; shift < num_slots / 2; shift <<= 1) { + OP_REQUIRES( + op_ctx, + shift - 1 < + static_cast(rot_keys.size()), // Skip key 0. + InvalidArgument("No key for shift of '", shift, "'")); + RotationKey const* k = rot_keys[shift].get(); + + // Rotate by the shift. + OP_REQUIRES_VALUE(auto ct_sub, op_ctx, + ct_result.Substitute(k->SubstitutionPower())); + OP_REQUIRES_VALUE(auto ct_rot, op_ctx, k->ApplyTo(ct_sub)); + + // Add to the sum. + OP_REQUIRES_OK(op_ctx, ct_result.AddInPlace(ct_rot)); + } + + } else if (reduction == fast_reduction) { OP_REQUIRES_VALUE( RnsPolynomial sum_component_zero, op_ctx, ct_result.Component(0)); // deep copy to start the sum. @@ -759,25 +789,8 @@ class MatMulPtCtOp : public OpKernel { ct_result.PowerOfS(), ct_result.Error() * ct_result.LogN(), ct_result.ErrorParams()}; - - } else { - for (int shift = 1; shift < num_slots / 2; shift <<= 1) { - OP_REQUIRES( - op_ctx, - shift - 1 < - static_cast(rot_keys.size()), // Skip key 0. - InvalidArgument("No key for shift of '", shift, "'")); - RotationKey const* k = rot_keys[shift].get(); - - // Rotate by the shift. 
- OP_REQUIRES_VALUE(auto ct_sub, op_ctx, - ct_result.Substitute(k->SubstitutionPower())); - OP_REQUIRES_VALUE(auto ct_rot, op_ctx, k->ApplyTo(ct_sub)); - - // Add to the sum. - OP_REQUIRES_OK(op_ctx, ct_result.AddInPlace(ct_rot)); - } } + // else if no reduction, do nothing. // At this point we have one ciphertext per row of the plaintext // matrix where every element in the ciphertext is the same value, @@ -896,57 +909,28 @@ REGISTER_KERNEL_BUILDER( // Matrix multiply plaintext and ciphertext. REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); + MatMulPtCtOp); REGISTER_KERNEL_BUILDER( Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); - -// Matrix multiply plaintext and ciphertext with fast rotations. -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); - -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); - -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); - -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); -REGISTER_KERNEL_BUILDER( - Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), - MatMulPtCtOp); \ No newline at end of file + MatMulPtCtOp); diff --git a/tf_shell/cc/kernels/segment_kernels.cc b/tf_shell/cc/kernels/segment_kernels.cc index 7edf84e..0ede42b 100644 --- a/tf_shell/cc/kernels/segment_kernels.cc +++ b/tf_shell/cc/kernels/segment_kernels.cc @@ -121,9 +121,9 @@ struct UnsortedSegmentFunctor { TensorShape const& segment_ids_shape, typename TTypes::ConstTensor segment_ids, typename TTypes::ConstTensor data, - typename TTypes::Tensor unreduced_output, typename TTypes::Tensor output, - typename TTypes::Tensor slot_counter) { + typename TTypes::Tensor slot_counter, + std::string const& reduction_type) { // Initialize the output. // auto initial_value = InitialValueF()(shell_ctx_var); for (int i = 0; i < output.dimension(0); ++i) { @@ -148,7 +148,7 @@ struct UnsortedSegmentFunctor { // Reduce `N` rows input to `num_segments` rows output. 
int64_t const N = segment_ids.dimension(1); - int64_t const num_segments = unreduced_output.dimension(0); + int64_t const num_segments = output.dimension(1); int64_t const inner_dim = data.dimension(1); int64_t const num_slots = 1 << shell_ctx_var->log_n_; ShellContext const* shell_ctx = shell_ctx_var->ct_context_.get(); @@ -180,14 +180,6 @@ struct UnsortedSegmentFunctor { return; } - // Initialize intermediate storage tensor. - for (int i = 0; i < unreduced_output.dimension(0); ++i) { - for (int ii = 0; ii < unreduced_output.dimension(1); ++ii) { - // unreduced_output(i, ii) = initial_value; - unreduced_output(i, ii) = InitialValueF()(shell_ctx_var); - } - } - // Step 1: Reduce over the ciphertext dimension. There are many slots in a // ciphertext, and some slots may be assigned to the same output. The // `reductionWorker1D` will extract the all slots for the same destination @@ -255,7 +247,7 @@ struct UnsortedSegmentFunctor { data_var->ct * mask_pt); SymmetricCtVariant* output_var = - unreduced_output((int64_t)j, chip).get>(); + output(0, (int64_t)j, chip).get>(); OP_REQUIRES(ctx, output_var != nullptr, InvalidArgument("SymmetricCtVariant for output did not " "unwrap successfully.")); @@ -269,7 +261,7 @@ struct UnsortedSegmentFunctor { // input's context to prevent premature deletion of the moduli. SymmetricCtVariant var(masked_data_ct, data_var->ct_context, data_var->error_params); - unreduced_output((int64_t)j, chip) = std::move(var); + output(0, (int64_t)j, chip) = std::move(var); } else { OP_REQUIRES_OK(ctx, reduction(masked_data_ct, output_var->ct)); } @@ -295,12 +287,10 @@ struct UnsortedSegmentFunctor { for (int64_t chip = 0; chip < inner_dim; ++chip) { // Start the reduction for the top and bottom halves of the ciphertext // with the output value. - SymmetricCt accum_top = unreduced_output(j, chip) - .get>() - ->ct; // deep copy - SymmetricCt accum_bottom = unreduced_output(j, chip) - .get>() - ->ct; // deep copy + SymmetricCt accum_top = + output(0, j, chip).get>()->ct; // deep copy + SymmetricCt accum_bottom = + output(0, j, chip).get>()->ct; // deep copy // No need to lazy decode the accums, they were created in this op. for (int64_t slot = 1; slot < num_slots; ++slot) { @@ -317,7 +307,7 @@ struct UnsortedSegmentFunctor { key = keys[key_slot].get(); SymmetricCt const& ct = - unreduced_output(j, chip).get>()->ct; + output(0, j, chip).get>()->ct; // Rotate. OP_REQUIRES_VALUE(auto ct_sub, ctx, @@ -344,13 +334,15 @@ struct UnsortedSegmentFunctor { } }; - // Use a fixed block size to avoid stragglers in the reduction. - int64_t const batchaxis_block_size = 2; - tsl::thread::ThreadPool::SchedulingParams batchaxis_scheduling_params( - tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, - std::nullopt, batchaxis_block_size); - thread_pool->ParallelFor(num_segments, batchaxis_scheduling_params, - batchAxisReductionWorker); + if (reduction_type == "galois") { + // Use a fixed block size to avoid stragglers in the reduction. 
+ int64_t const batchaxis_block_size = 2; + tsl::thread::ThreadPool::SchedulingParams batchaxis_scheduling_params( + tsl::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, + std::nullopt, batchaxis_block_size); + thread_pool->ParallelFor(num_segments, batchaxis_scheduling_params, + batchAxisReductionWorker); + } } }; @@ -378,7 +370,7 @@ struct SumOp { // check routines not in the templated class to reduce code size template Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, - OpKernelContext* context, + OpKernelContext* op_ctx, ContextVariant const* shell_ctx_var, Tensor const& data, Tensor const& segment_ids, @@ -416,82 +408,104 @@ class UnsortedSegmentReductionOp : public OpKernel { private: using ModularInt = rlwe::MontgomeryInt; using ShellContext = rlwe::RnsContext; + using RotationKey = rlwe::RnsGaloisKey; + + std::string reduction_type; + char const* galois_reduction = "galois"; + char const* no_reduction = "none"; public: - explicit UnsortedSegmentReductionOp(OpKernelConstruction* context) - : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {} + explicit UnsortedSegmentReductionOp(OpKernelConstruction* op_ctx) + : OpKernel(op_ctx), reduction_functor_(DeviceReductionFunctor()) { + OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("reduction", &reduction_type)); + OP_REQUIRES(op_ctx, reduction_type == "galois" || reduction_type == "none", + InvalidArgument("Invalid reduction attribute: ", reduction_type, + ". Must be 'galois' or 'none'.")); + } - void Compute(OpKernelContext* context) override { - OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, context, - GetVariant>(context, 0)); + void Compute(OpKernelContext* op_ctx) override { + // Recover the input tensors. + OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, + GetVariant>(op_ctx, 0)); int64_t const num_slots = 1 << shell_ctx_var->log_n_; - Tensor const& data = context->input(1); - Tensor const& segment_ids = context->input(2); - Tensor const& num_segments = context->input(3); - OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, context, - GetVariant>(context, 4)); - OP_REQUIRES( - context, rotation_key_var != nullptr, - InvalidArgument("RotationKeyVariant did not unwrap successfully.")); - OP_REQUIRES_OK(context, const_cast*>(rotation_key_var) - ->MaybeLazyDecode(shell_ctx_var->ct_context_)); - - OP_REQUIRES_OK(context, ValidateUnsortedSegmentReduction( - this, context, shell_ctx_var, data, segment_ids, - num_segments)); + Tensor const& data = op_ctx->input(1); + Tensor const& segment_ids = op_ctx->input(2); + Tensor const& num_segments = op_ctx->input(3); + + // Validate the input tensors. + OP_REQUIRES_OK(op_ctx, ValidateUnsortedSegmentReduction( + this, op_ctx, shell_ctx_var, data, segment_ids, + num_segments)); Index const output_rows = static_cast( num_segments.dtype() == DT_INT32 ? num_segments.scalar()() : num_segments.scalar()()); - OP_REQUIRES(context, output_rows >= 0, + OP_REQUIRES(op_ctx, output_rows >= 0, InvalidArgument("Input num_segments == ", output_rows, " must not be negative.")); + // Recover the rotation keys if needed. 
+ RotationKeyVariant const* rotation_key_var = nullptr; + if (reduction_type == galois_reduction) { + OP_REQUIRES_VALUE(rotation_key_var, op_ctx, + GetVariant>(op_ctx, 4)); + OP_REQUIRES( + op_ctx, rotation_key_var != nullptr, + InvalidArgument("RotationKeyVariant did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(rotation_key_var) + ->MaybeLazyDecode(shell_ctx_var->ct_context_)); + } + std::vector> empty_rot_keys{}; + std::vector> const& rot_keys = + reduction_type == galois_reduction ? rotation_key_var->keys + : empty_rot_keys; + + // Build the output tensor shape. TensorShape output_shape; - OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(output_rows)); + OP_REQUIRES_OK(op_ctx, output_shape.AddDimWithStatus(output_rows)); for (int i = segment_ids.dims() - 1; i < data.dims(); i++) { // -1 for batch axis packing. - OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(data.dim_size(i))); + OP_REQUIRES_OK(op_ctx, output_shape.AddDimWithStatus(data.dim_size(i))); } - // `unreduced_data` is a temporary tensor to store the intermediate result - // before reducing over the packing dimension. - Tensor unreduced_data; - OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, output_shape, - &unreduced_data)); - - // `slot_counter` records which slots in the unreduced_output ciphertext - // contain real data. This is used to rotate the occupied slots in - // unreduced_output to the first (and mid) positions for the real output. - // Since this is a plaintext output, prepend the packing dimension. - // TensorShape slot_counter_shape = output_shape; - // OP_REQUIRES_OK(context, slot_counter_shape.InsertDimWithStatus(0, + // `slot_counter` records which slots in the output ciphertexts contain real + // data. This is used to rotate the occupied slots in unreduced_output to + // the first (and mid) positions. Since this is a plaintext output, prepend + // the packing dimension. TensorShape slot_counter_shape = output_shape; + // OP_REQUIRES_OK(op_ctx, slot_counter_shape.InsertDimWithStatus(0, // num_slots)); TensorShape slot_counter_shape = {num_slots, output_rows}; Tensor* slot_counter = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(1, slot_counter_shape, - &slot_counter)); - - // The reduced output will have a prefix dimension of 2, corresponding to - // the result from the top and bottom half of the ciphertexts. - OP_REQUIRES_OK(context, output_shape.InsertDimWithStatus(0, 2)); + OP_REQUIRES_OK( + op_ctx, op_ctx->allocate_output(1, slot_counter_shape, &slot_counter)); + + int64_t output_prefix_dim; + + if (reduction_type == galois_reduction) { + // The reduced output will have a prefix dimension of 2, corresponding to + // the result from the top and bottom half of the ciphertexts. 
+ output_prefix_dim = 2; + OP_REQUIRES_OK(op_ctx, + output_shape.InsertDimWithStatus(0, output_prefix_dim)); + } else { + output_prefix_dim = 1; + } Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto const segment_flat = segment_ids.flat_outer_dims(); - auto unreduced_data_flat = unreduced_data.flat_outer_dims(); auto output_flat = - output->shaped({2, unreduced_data_flat.dimension(0), - unreduced_data_flat.dimension(1)}); + output->flat_inner_outer_dims(output_prefix_dim - 2); auto data_flat = data.flat_inner_outer_dims(segment_ids.dims() - 1 - 1); // -1 because flat_inner_outer_dims arg is an includsive range, // -1 again for batch axis packing (dimension 0 of data is imaginary). auto slot_counter_matrix = slot_counter->flat_outer_dims(); - reduction_functor_(context, shell_ctx_var, rotation_key_var->keys, - segment_ids.shape(), segment_flat, data_flat, - unreduced_data_flat, output_flat, slot_counter_matrix); + reduction_functor_(op_ctx, shell_ctx_var, rot_keys, segment_ids.shape(), + segment_flat, data_flat, output_flat, + slot_counter_matrix, reduction_type); } protected: diff --git a/tf_shell/cc/ops/shape_inference.cc b/tf_shell/cc/ops/shape_inference.cc index 9d98454..df60e57 100644 --- a/tf_shell/cc/ops/shape_inference.cc +++ b/tf_shell/cc/ops/shape_inference.cc @@ -142,18 +142,25 @@ Status ShellSegmentReductionWithNumSegmentsShape(InferenceContext* c) { DimensionHandle num_segments_dim; TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &num_segments_dim)); - // Output is {2} + {segment_id_rank} + s_data[segment_id_rank - - // 1:]. 2 is because the top and bottom of the ciphertexts are treated + std::string reduction_type; + TF_RETURN_IF_ERROR(c->GetAttr("reduction", &reduction_type)); + + // When the reduction type is galois, the output shape is: + // [2, segment_id_rank, s_data[segment_id_rank - 1:]]. + // 2 is because the top and bottom of the ciphertexts are treated // independently. The packing dimension is not included as the output is // a ciphertext and holds this dimension implicitly. + // When the reduction type is none, the 2 is omitted. 
ShapeHandle s_data_suffix; auto rank = c->Rank(s_segment_ids_suffix); TF_RETURN_IF_ERROR(c->Subshape(s_data, rank, &s_data_suffix)); TF_RETURN_IF_ERROR( c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &data_out)); - TF_RETURN_IF_ERROR( - c->Concatenate(c->Vector(c->MakeDim(2)), data_out, &data_out)); + if (reduction_type == "galois") { + TF_RETURN_IF_ERROR( + c->Concatenate(c->Vector(c->MakeDim(2)), data_out, &data_out)); + } TF_RETURN_IF_ERROR(c->WithRankAtLeast(s_segment_ids, 1, &s_segment_ids)); DimensionHandle num_slots = c->Dim(s_segment_ids, 0); diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index 1dd9101..0f1a052 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -204,15 +204,7 @@ REGISTER_OP("MatMulPtCt64") .Input("a: Dtype") .Input("b: variant") .Input("rotation_key: variant") - .Output("c: variant") - .SetShapeFn(ShellMatMulPtCtShape); - -REGISTER_OP("FastMatMulPtCt64") - .Attr("Dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") - .Input("context: variant") - .Input("a: Dtype") - .Input("b: variant") - // no rotation key + .Attr("reduction: string") .Output("c: variant") .SetShapeFn(ShellMatMulPtCtShape); @@ -441,6 +433,7 @@ REGISTER_OP("UnsortedCtSegmentSum") .Input("segment_ids: Tindices") .Input("num_segments: Tnumsegments") .Input("rotation_key: variant") + .Attr("reduction: string") .Output("output: variant") .Output("reduction_counts: Tindices") .Attr("Tindices: {int32,int64}") diff --git a/tf_shell/python/shell_ops.py b/tf_shell/python/shell_ops.py index 6312cec..221f712 100644 --- a/tf_shell/python/shell_ops.py +++ b/tf_shell/python/shell_ops.py @@ -50,7 +50,6 @@ mul_pt_pt64 = shell_ops.mul_pt_pt64 mat_mul_ct_pt64 = shell_ops.mat_mul_ct_pt64 mat_mul_pt_ct64 = shell_ops.mat_mul_pt_ct64 -fast_mat_mul_pt_ct64 = shell_ops.fast_mat_mul_pt_ct64 # Rotate slots. rotation_key_gen64 = shell_ops.rotation_key_gen64 diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 29d09fa..59a25cf 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -842,7 +842,7 @@ def fast_reduce_sum(x): ) -def matmul(x, y, rotation_key=None, fast=False): +def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"): """Matrix multiplication is specialized to whether the operands are plaintext or ciphertext. @@ -888,53 +888,48 @@ def matmul(x, y, rotation_key=None, fast=False): f"Underlying dtypes must match. Got {x.dtype} and {y._underlying_dtype}" ) + if pt_ct_reduction not in ["galois", "fast", "none"]: + raise ValueError( + f"pt_ct_reduction must be 'galois', 'fast', or 'none'. Got {pt_ct_reduction}." + ) + # Encode the plaintext x to the same scaling factor as y. scaled_x = _encode_scaling(x, y._scaling_factor) - if fast: - if y._is_fast_rotated: - raise ValueError( - "A ShellTensor which has been fast-reduced-summed cannot be fast-reduced-summed again." - ) - return ShellTensor64( - _raw_tensor=shell_ops.fast_mat_mul_pt_ct64( - y._context._get_context_at_level(y._level), - scaled_x, - y._raw_tensor, - # no rotation key - ), - _context=y._context, - _level=y._level, - _num_mod_reductions=y._num_mod_reductions, - _underlying_dtype=y._underlying_dtype, - _scaling_factor=y._scaling_factor**2, - _is_enc=True, - _is_fast_rotated=True, - ) - else: + if pt_ct_reduction == "galois": if not isinstance(rotation_key, ShellRotationKey64): raise ValueError( - f"Rotation key must be provided to matmul pt*ct. Instead saw {rotation_key}." 
+ f"Rotation key must be provided to matmul pt*ct with galois reduction. Instead saw {rotation_key}." ) - # Get the correct rotation key for the level of y. raw_rotation_key = rotation_key._get_key_at_level(y._level) + elif pt_ct_reduction == "fast": + if y._is_fast_rotated: + raise ValueError( + "A ShellTensor which has been fast-reduced-summed cannot be fast-reduced-summed again." + ) + # Any variant tensor will do. It is ignored by the op. + raw_rotation_key = y._context._get_context_at_level(y._level) + elif pt_ct_reduction == "none": + # Any variant tensor will do. It is ignored by the op. + raw_rotation_key = y._context._get_context_at_level(y._level) - return ShellTensor64( - _raw_tensor=shell_ops.mat_mul_pt_ct64( - y._context._get_context_at_level(y._level), - scaled_x, - y._raw_tensor, - raw_rotation_key, - ), - _context=y._context, - _level=y._level, - _num_mod_reductions=y._num_mod_reductions, - _underlying_dtype=y._underlying_dtype, - _scaling_factor=y._scaling_factor**2, - _is_enc=True, - _is_fast_rotated=y._is_fast_rotated, - ) + return ShellTensor64( + _raw_tensor=shell_ops.mat_mul_pt_ct64( + y._context._get_context_at_level(y._level), + scaled_x, + y._raw_tensor, + raw_rotation_key, + reduction=pt_ct_reduction, + ), + _context=y._context, + _level=y._level, + _num_mod_reductions=y._num_mod_reductions, + _underlying_dtype=y._underlying_dtype, + _scaling_factor=y._scaling_factor**2, + _is_enc=True, + _is_fast_rotated=pt_ct_reduction == "fast", + ) elif isinstance(x, ShellTensor64) and isinstance(y, ShellTensor64): return NotImplementedError @@ -1061,16 +1056,24 @@ def split(x, num_or_size_splits, axis=0, num_splits=None): raise ValueError("Unsupported type for expand_dims") -def segment_sum(x, segments, num_segments, rotation_key=None): +def segment_sum(x, segments, num_segments, rotation_key=None, reduction="galois"): if not isinstance(segments, tf.Tensor): raise ValueError("`segments` must be a TensorFlow tensor.") if isinstance(x, ShellTensor64): - if not isinstance(rotation_key, ShellRotationKey64): - raise ValueError( - f"Rotation key must be provided. Instead saw {rotation_key}." - ) - raw_rotation_key = rotation_key._get_key_at_level(x._level) + if reduction not in ["galois", "none"]: + raise ValueError(f"Reduction must be 'galois' or 'none'. Got {reduction}.") + + if reduction == "galois": + if not isinstance(rotation_key, ShellRotationKey64): + raise ValueError( + f"Rotation key must be provided for galois-based reduction. Instead saw {rotation_key}." + ) + raw_rotation_key = rotation_key._get_key_at_level(x._level) + + elif reduction == "none": + # Any variant tensor will do. It is ignored by the op. 
+ raw_rotation_key = x._context._get_context_at_level(x._level) raw_result, reduction_count = shell_ops.segment_sum_ct( x._context._get_context_at_level(x._level), @@ -1078,6 +1081,7 @@ def segment_sum(x, segments, num_segments, rotation_key=None): segments, num_segments, raw_rotation_key, + reduction=reduction, ) return ( diff --git a/tf_shell/test/mat_mul_test.py b/tf_shell/test/mat_mul_test.py index 985c124..bfe522f 100644 --- a/tf_shell/test/mat_mul_test.py +++ b/tf_shell/test/mat_mul_test.py @@ -179,7 +179,7 @@ def _test_tf_ct_matmul(self, test_context, use_fast_rotation): @tf.function def test_functor(): if use_fast_rotation: - ec = tf_shell.matmul(a, eb, fast=True) + ec = tf_shell.matmul(a, eb, pt_ct_reduction="fast") else: ec = tf_shell.matmul(a, eb, test_context.rotation_key) # Tests shape inference diff --git a/tf_shell/test/segment_test.py b/tf_shell/test/segment_test.py index 9087b36..6650e38 100644 --- a/tf_shell/test/segment_test.py +++ b/tf_shell/test/segment_test.py @@ -194,6 +194,76 @@ def test_segment_sum(self): ): self._test_segment_sum(test_context, segment_creator) + def _test_segment_sum_no_reduction(self, test_context, segment_creator_functor): + repeats = 8 + num_segments = test_context.shell_context.num_slots.numpy() // repeats + + a = self.create_rand_data(test_context, repeats) + if a is None: + return + + sa = tf_shell.to_shell_plaintext(a, test_context.shell_context) + ea = tf_shell.to_encrypted(sa, test_context.key, test_context.shell_context) + + segments = segment_creator_functor(test_context, repeats, num_segments) + segments_shape_should_be = [ + test_context.shell_context.num_slots.numpy(), + num_segments, + ] + counts_shape_should_be = [ + test_context.shell_context.num_slots.numpy(), + num_segments, + ] + + @tf.function + def test_functor(ea, segments, num_segments): + ess, counts = tf_shell.segment_sum( + ea, segments, num_segments, reduction="none" + ) + # Tests shape inference + self.assertEqual(ess.shape.ndims, len(segments_shape_should_be)) + for i in range(ess.shape.ndims): + if ess.shape[i] is not None: + self.assertEqual(ess.shape[i], segments_shape_should_be[i]) + self.assertEqual(counts.shape.ndims, len(counts_shape_should_be)) + for i in range(counts.shape.ndims): + if counts.shape[i] is not None: + self.assertEqual(counts.shape[i], counts_shape_should_be[i]) + + return ess, counts + + ess, counts = test_functor(ea, segments, num_segments) + + ss = tf_shell.to_tensorflow(ess, test_context.key) + + pt_result = tf.math.unsorted_segment_sum(a, segments, num_segments) + + # Ensure the data is correct. + self.assertAllClose(pt_result, tf.reduce_sum(ss, axis=0)) + + # Ensure the counts are correct. + def bincount(x): + return tf.math.bincount(x, minlength=num_segments, maxlength=num_segments) + + segments_nonnegative = tf.where(segments >= 0, segments, num_segments + 1) + pt_counts = tf.map_fn(bincount, segments_nonnegative) + self.assertAllEqual(pt_counts, counts) + + # Ensure initial arguments are not modified. + self.assertAllClose(a, tf_shell.to_tensorflow(sa)) + self.assertAllClose(a, tf_shell.to_tensorflow(ea, test_context.key)) + + def test_segment_sum_no_reduction(self): + for test_context in self.test_contexts: + for segment_creator in [ + self.create_uniform_segments, + self.create_nonuniform_segments, + ]: + with self.subTest( + f"{self._testMethodName} with context `{test_context}` and segment creator `{segment_creator}``." 
+ ): + self._test_segment_sum_no_reduction(test_context, segment_creator) + def _test_segment_sum_fewer_dims(self, test_context, segment_creator_functor): repeats = 8 num_segments = test_context.shell_context.num_slots.numpy() // repeats diff --git a/tf_shell/test/test_utils.py b/tf_shell/test/test_utils.py index 9b4de67..b8be8ba 100644 --- a/tf_shell/test/test_utils.py +++ b/tf_shell/test/test_utils.py @@ -189,6 +189,7 @@ def uniform_for_n_muls(test_context, num_muls, shape=None, subsequent_adds=0): min_val, max_val = get_bounds_for_n_muls(test_context, num_muls) + subsequent_adds = tf.cast(subsequent_adds, min_val.dtype) min_val = min_val / (subsequent_adds + 1) max_val = max_val / (subsequent_adds + 1) diff --git a/tf_shell_ml/conv2d.py b/tf_shell_ml/conv2d.py index 42ccf3d..9f6779d 100644 --- a/tf_shell_ml/conv2d.py +++ b/tf_shell_ml/conv2d.py @@ -33,7 +33,7 @@ def __init__( kernel_initializer="glorot_uniform", weight_dtype=tf.float32, is_first_layer=False, - use_fast_reduce_sum=False, + grad_reduction="none", ): super().__init__() self.filters = int(filters) @@ -60,7 +60,12 @@ def __init__( self.kernel_initializer = initializers.get(kernel_initializer) self.weight_dtype = weight_dtype self.is_first_layer = is_first_layer - self.use_fast_reduce_sum = use_fast_reduce_sum + self.grad_reduction = grad_reduction + + if grad_reduction not in ["galois", "fast", "none"]: + raise ValueError( + f"Invalid grad_reduction type: {grad_reduction} (must be 'galois', 'fast', or 'none')" + ) def get_config(self): config = super().get_config() @@ -133,10 +138,10 @@ def backward(self, dy, rotation_key): x, dy_exp, [1, 1, 1, 1], self.padding, self.strides, with_channel=True ) - if self.use_fast_reduce_sum: - d_w = tf_shell.fast_reduce_sum(d_w) - else: + if self.grad_reduction == "galois": d_w = tf_shell.reduce_sum(d_w, 0, rotation_key) + elif self.grad_reduction == "fast": + d_w = tf_shell.fast_reduce_sum(d_w) grad_weights.append(d_w) diff --git a/tf_shell_ml/dense.py b/tf_shell_ml/dense.py index 642b5bd..7f556cf 100644 --- a/tf_shell_ml/dense.py +++ b/tf_shell_ml/dense.py @@ -28,10 +28,9 @@ def __init__( use_bias=False, kernel_initializer="glorot_uniform", bias_initializer="zeros", - skip_normalization=True, weight_dtype=tf.float32, is_first_layer=False, - use_fast_reduce_sum=False, + grad_reduction="none", ): super().__init__() self.units = int(units) @@ -41,10 +40,14 @@ def __init__( self.kernel_initializer = initializers.get(kernel_initializer) self.bias_initializer = initializers.get(bias_initializer) - self.skip_normalization = skip_normalization self.weight_dtype = weight_dtype self.is_first_layer = is_first_layer - self.use_fast_reduce_sum = use_fast_reduce_sum + self.grad_reduction = grad_reduction + + if grad_reduction not in ["galois", "fast", "none"]: + raise ValueError( + f"Invalid grad_reduction type: {grad_reduction} (must be 'galois', 'fast', or 'none')" + ) def get_config(self): config = super().get_config() @@ -94,7 +97,7 @@ def backward(self, dy, rotation_key): x = self._layer_input z = self._layer_intermediate kernel = self.weights[0] - grad_weights = [] + d_ws = [] # On the forward pass, inputs may be batched differently than the # ciphertext scheme when not in eager mode. Pad them to match the @@ -116,27 +119,25 @@ def backward(self, dy, rotation_key): d_x = tf_shell.matmul(dy, kernel_t) # Perform the multiplication for dy/dw. 
- if self.use_fast_reduce_sum: - d_weights = tf_shell.matmul(tf.transpose(x), dy, fast=True) - else: - d_weights = tf_shell.matmul(tf.transpose(x), dy, rotation_key) - - if not self.skip_normalization: - batch_size = tf.shape(plaintext_packed_dx)[0] - d_weights = d_weights / batch_size - grad_weights.append(d_weights) + d_w = tf_shell.matmul( + tf.transpose(x), dy, rotation_key, pt_ct_reduction=self.grad_reduction + ) + d_ws.append(d_w) if self.use_bias: - if self.use_fast_reduce_sum: + if self.grad_reduction == "galois": + d_bias = tf_shell.reduce_sum(dy, axis=0, rotation_key=rotation_key) + elif self.grad_reduction == "fast": d_bias = tf_shell.fast_reduce_sum(dy) else: - d_bias = tf_shell.reduce_sum(dy, axis=0, rotation_key=rotation_key) + if not isinstance(dy, tf_shell.ShellTensor64): + d_bias = tf.reduce_sum(dy, axis=0) + else: + d_bias = dy - if not self.skip_normalization: - d_bias = d_bias / batch_size - grad_weights.append(d_bias) + d_ws.append(d_bias) - return grad_weights, d_x + return d_ws, d_x @staticmethod def unpack(plaintext_packed_dx): diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py index de85155..df7e03a 100644 --- a/tf_shell_ml/dpsgd_sequential_model.py +++ b/tf_shell_ml/dpsgd_sequential_model.py @@ -84,13 +84,6 @@ def call(self, x, training=False): x = l(x, training=training) return x - def build(self, input_shape): - super().build(input_shape) - self.unpacking_funcs = [] - for l in self.layers: - if hasattr(l, "unpacking_funcs"): - self.unpacking_funcs.extend(l.unpacking_funcs()) - def compute_max_two_norm_and_pred(self, features, skip_two_norm): with tf.GradientTape(persistent=tf.executing_eagerly()) as tape: y_pred = self(features, training=True) # forward pass @@ -187,7 +180,7 @@ def train_step(self, data): t = tf.cast(backprop_context.plaintext_modulus, tf.float32) t_half = t // 2 mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)] - mask = [ + masks = [ tf.random.uniform( tf_shell.shape(g), dtype=tf.float32, @@ -195,13 +188,11 @@ def train_step(self, data): maxval=t_half / s, ) for g, s in zip(reversed(dJ_dw), mask_scaling_factors) - # tf.zeros_like(tf_shell.shape(g), dtype=tf.int64) - # for g in dJ_dw ] # Mask the encrypted gradients and reverse the order to match # the order of the layers. - grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)] + grads = [(g + m) for g, m in zip(reversed(dJ_dw), masks)] if not self.disable_noise: # Features party encrypts the max two norm to send to the labels @@ -217,7 +208,7 @@ def train_step(self, data): with tf.device(self.labels_party_dev): if not self.disable_encryption: # Decrypt the weight gradients with the backprop key. - packed_grads = [ + grads = [ tf_shell.to_tensorflow( g, ( @@ -229,12 +220,14 @@ def train_step(self, data): for g in grads ] - # Unpack the plaintext gradients using the corresponding layer's - # unpack function. - # TODO: Make sure this doesn't require sending the layers - # themselves just for unpacking. The weights should not be - # shared with the labels party. - grads = [f(g) for f, g in zip(self.unpacking_funcs, packed_grads)] + # Sum the masked gradients over the batch. + if self.disable_masking or self.disable_encryption: + grads = [tf.reduce_sum(g, 0) for g in grads] + else: + grads = [ + tf_shell.reduce_sum_with_mod(g, 0, backprop_context, s) + for g, s in zip(grads, mask_scaling_factors) + ] if not self.disable_noise: tf.assert_equal( @@ -271,7 +264,7 @@ def train_step(self, data): # secret key. 
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key) - if self.check_overflow_INSECURE: + if not self.disable_encryption and self.check_overflow_INSECURE: nosie_scaling_factors = grads._scaling_factor self.warn_on_overflow( [flat_grads], @@ -289,12 +282,19 @@ def train_step(self, data): ) if not self.disable_masking and not self.disable_encryption: + # Sum the masks over the batch. + sum_masks = [ + tf_shell.reduce_sum_with_mod(m, 0, backprop_context, s) + for m, s in zip(masks, mask_scaling_factors) + ] + + # Unmask the batch gradient. + grads = [mg - m for mg, m in zip(grads, sum_masks)] + # SHELL represents floats as integers between [0, t) where t is - # the plaintext modulus. To mimic the modulo operation without - # SHELL, numbers which exceed the range [-t/2, t/2) are shifted - # back into the range. In this context, t is the plaintext - # modulus of the backprop context, since that what the gradients - # were encrypted with when the mask was added. + # the plaintext modulus. To mimic SHELL's modulo operations in + # TensorFlow, numbers which exceed the range [-t/2, t/2] are + # shifted back into the range. epsilon = tf.constant(1e-6, dtype=float) def rebalance(x, s): @@ -305,15 +305,6 @@ def rebalance(x, s): x = tf.where(x < l_bound, x + t_over_s, x) return x - # Unmask the gradients using the mask. The unpacking function may - # sum the mask from two of the gradients (one from each batch), so - # the mask must be brought back into the range of [-t/2, t/2] before - # subtracting it from the gradient, and again after. - unpacked_mask = [f(m) for f, m in zip(self.unpacking_funcs, mask)] - unpacked_mask = [ - rebalance(m, s) for m, s in zip(unpacked_mask, mask_scaling_factors) - ] - grads = [mg - m for mg, m in zip(grads, unpacked_mask)] grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)] # Apply the gradients to the model. diff --git a/tf_shell_ml/embedding.py b/tf_shell_ml/embedding.py index 56faeda..93fb31d 100644 --- a/tf_shell_ml/embedding.py +++ b/tf_shell_ml/embedding.py @@ -26,12 +26,19 @@ def __init__( output_dim, embeddings_initializer="uniform", skip_embeddings_below_index=0, + grad_reduction="none", ): super().__init__() self.input_dim = int(input_dim) self.output_dim = int(output_dim) self.embeddings_initializer = initializers.get(embeddings_initializer) self.skip_embeddings_below_index = skip_embeddings_below_index + self.grad_reduction = grad_reduction + + if grad_reduction not in ["galois", "none"]: + raise ValueError( + f"Invalid grad_reduction type: {grad_reduction} (must be 'galois' or 'none')" + ) def get_config(self): config = super().get_config() @@ -112,6 +119,7 @@ def backward(self, dy, rotation_key): indices, self.input_dim, rotation_key, + reduction=self.grad_reduction, ) return [summedvalues], tf.zeros(0) diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py index 0fc63e3..7e09601 100644 --- a/tf_shell_ml/postscale_sequential_model.py +++ b/tf_shell_ml/postscale_sequential_model.py @@ -272,10 +272,10 @@ def train_step(self, data): # Unmask the batch gradient. grads = [mg - m for mg, m in zip(grads, sum_masks)] - # SHELL represents floats as integers between [0, t) where t is the - # plaintext modulus. To mimic SHELL's modulo operations in - # TensorFlow, numbers which exceed the range [-t/2, t/2] are shifted - # back into the range. + # SHELL represents floats as integers between [0, t) where t is + # the plaintext modulus. 
To mimic SHELL's modulo operations in + # TensorFlow, numbers which exceed the range [-t/2, t/2] are + # shifted back into the range. epsilon = tf.constant(1e-6, dtype=float) def rebalance(x, s): diff --git a/tf_shell_ml/test/conv2d_test.py b/tf_shell_ml/test/conv2d_test.py index f411168..3b31189 100644 --- a/tf_shell_ml/test/conv2d_test.py +++ b/tf_shell_ml/test/conv2d_test.py @@ -82,7 +82,7 @@ def _test_conv2d_plaintext_forward_backward_correct( tf_dw = tape.gradient(y, tf_conv_layer.trainable_variables) self.assertAllClose(dx, tf_dx) - self.assertAllClose(dws[0], tf_dw[0]) + self.assertAllClose(tf.reduce_sum(dws[0], axis=0), tf_dw[0]) def test_conv2d_plaintext_forward_backward_correct(self): for stride in [1, 2]: @@ -124,7 +124,7 @@ def forward_backward(x): # Encrypted backward pass. enc_dw, enc_dx = conv_layer.backward(enc_dy, rotation_key) dw = tf_shell.to_tensorflow(enc_dw[0], key) - dw = conv_layer.unpack(dw) + # dw = conv_layer.unpack(dw) # for layer reduction 'fast' or 'galois' dx = tf_shell.to_tensorflow(enc_dx, key) # Plaintext backward pass. diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py index 6e303f9..95f8063 100644 --- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py @@ -55,7 +55,6 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): kernel_size=8, strides=2, padding="SAME", - use_fast_reduce_sum=True, ), tf_shell_ml.MaxPool2D( pool_size=(2, 2), @@ -65,18 +64,15 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): filters=32, kernel_size=4, strides=2, - use_fast_reduce_sum=True, ), tf_shell_ml.Flatten(), tf_shell_ml.ShellDense( 16, activation=tf.nn.softmax, - use_fast_reduce_sum=True, ), tf_shell_ml.ShellDense( 10, activation=tf.nn.softmax, - use_fast_reduce_sum=True, ), ], backprop_context_fn=lambda: tf_shell.create_autocontext64( diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py index ab04ba9..f62c110 100644 --- a/tf_shell_ml/test/dpsgd_model_distrib_test.py +++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py @@ -88,12 +88,10 @@ def test_model(self): 64, activation=tf_shell_ml.relu, activation_deriv=tf_shell_ml.relu_deriv, - use_fast_reduce_sum=True, ), tf_shell_ml.ShellDense( 10, activation=tf.nn.softmax, - use_fast_reduce_sum=True, ), ], lambda: tf_shell.create_autocontext64( diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py index 7bfe911..09cc56d 100644 --- a/tf_shell_ml/test/dpsgd_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_model_local_test.py @@ -52,12 +52,10 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): 64, activation=tf_shell_ml.relu, activation_deriv=tf_shell_ml.relu_deriv, - use_fast_reduce_sum=True, ), tf_shell_ml.ShellDense( 10, activation=tf.nn.softmax, - use_fast_reduce_sum=True, ), ], lambda: tf_shell.create_autocontext64( diff --git a/tf_shell_ml/test/embedding_test.py b/tf_shell_ml/test/embedding_test.py index d183eb0..013dd6d 100644 --- a/tf_shell_ml/test/embedding_test.py +++ b/tf_shell_ml/test/embedding_test.py @@ -79,25 +79,25 @@ def forward_backward(x): enc_dy = tf_shell.to_encrypted(dy, key, context) enc_dw, _ = embedding_layer.backward(enc_dy, rotation_key) - packed_dx = tf_shell.to_tensorflow(enc_dw[0], key) - dx = embedding_layer.unpack(packed_dx) + dw = tf_shell.to_tensorflow(enc_dw[0], key) + # dw = embedding_layer.unpack(packed_dw) + dw = 
tf.reduce_sum(dw, axis=0) + return dw - return dx - - dx = forward_backward(x) + dw = forward_backward(x) for i in range(0, input_dim): - # Check dx[ special_index] has counted the number of elements. + # Check dw[ special_index] has counted the number of elements. if i == special_index: self.assertAllEqual( - dx[special_index, :], + dw[special_index, :], tf.constant( context.num_slots * sentence_length, shape=(output_dim,) ), ) # Make sure the rest of the gradient elements are 0. else: - self.assertAllEqual(dx[i, :], tf.constant(0, shape=(output_dim,))) + self.assertAllEqual(dw[i, :], tf.constant(0, shape=(output_dim,))) def test_embedding_eager(self): tf.config.run_functions_eagerly(True) diff --git a/tf_shell_ml/test/mnist_enc_backprop_test.py b/tf_shell_ml/test/mnist_enc_backprop_test.py index 745cf81..5285aae 100644 --- a/tf_shell_ml/test/mnist_enc_backprop_test.py +++ b/tf_shell_ml/test/mnist_enc_backprop_test.py @@ -91,7 +91,7 @@ def _test_mnist_enc_backprop(self, use_fast_reduce_sum): activation=tf_shell_ml.relu, activation_deriv=tf_shell_ml.relu_deriv, is_first_layer=True, - use_fast_reduce_sum=use_fast_reduce_sum, + grad_reduction="fast" if use_fast_reduce_sum else "galois", ) output_layer = tf_shell_ml.ShellDense( 10, @@ -102,7 +102,7 @@ def _test_mnist_enc_backprop(self, use_fast_reduce_sum): # to compute than each of them individually). So instead just let the # loss function derivative incorporate y_pred - y and let the derivative # of this last layer's activation be a no-op. - use_fast_reduce_sum=use_fast_reduce_sum, + grad_reduction="fast" if use_fast_reduce_sum else "galois", ) # Call the layers once to create the weights.
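# --- Reviewer addendum (not part of the patch above) --------------------------
# A minimal usage sketch of the new `reduction` options introduced by this diff,
# written against the Python API changes in tf_shell/python/shell_tensor.py.
# The `context`, `key`, `rotation_key`, and input tensors are assumed to be
# created elsewhere (e.g. as in tf_shell/test/test_utils.py); shapes are
# schematic and only the call signatures are the point here.
import tensorflow as tf
import tf_shell

def reduction_examples(context, key, rotation_key, x, y, segments, num_segments):
    # Encrypt y (a plain tf.Tensor) as the tests above do.
    enc_y = tf_shell.to_encrypted(tf_shell.to_shell_plaintext(y, context), key, context)

    # Plaintext-ciphertext matmul: "galois" reduces across slots using rotation
    # keys, "fast" uses the fast-rotation path (the result cannot be
    # fast-reduced again), and "none" skips the slot reduction entirely.
    c_galois = tf_shell.matmul(x, enc_y, rotation_key)           # pt_ct_reduction="galois" (default)
    c_fast = tf_shell.matmul(x, enc_y, pt_ct_reduction="fast")   # no rotation key needed
    c_none = tf_shell.matmul(x, enc_y, pt_ct_reduction="none")   # caller reduces after decryption

    # Segment sum: with reduction="none" no rotation key is required and the
    # output keeps the packing dimension, so the caller finishes the sum after
    # decryption, as the new test in segment_test.py does.
    ess, counts = tf_shell.segment_sum(enc_y, segments, num_segments, rotation_key=rotation_key)
    ess_none, counts_none = tf_shell.segment_sum(enc_y, segments, num_segments, reduction="none")
    summed = tf.reduce_sum(tf_shell.to_tensorflow(ess_none, key), axis=0)
    return c_galois, c_fast, c_none, ess, counts, summed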
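# --- Reviewer addendum (not part of the patch above) --------------------------
# A plain-NumPy model of what the "galois" rotate-and-add reduction in
# MatMulPtCtOp computes, to clarify why reduction="none" leaves per-slot
# partial results that must be summed after decryption. This only models slot
# arithmetic; real ciphertext rotations act on the two halves of a ciphertext
# independently, which is mirrored below. Hypothetical helper, not tf_shell code.
import numpy as np

def galois_style_reduce(slots):
    """Rotate-and-add each half of a slot vector, as the kernel's shift loop does."""
    n = slots.shape[0]
    top, bottom = slots[: n // 2].copy(), slots[n // 2:].copy()
    shift = 1
    while shift < n // 2:  # mirrors: for (shift = 1; shift < num_slots / 2; shift <<= 1)
        top = top + np.roll(top, -shift)
        bottom = bottom + np.roll(bottom, -shift)
        shift <<= 1
    return np.concatenate([top, bottom])  # every slot in a half now holds that half's sum

slots = np.arange(8, dtype=np.int64)  # toy example with 8 slots
reduced = galois_style_reduce(slots)
assert reduced[0] == slots[:4].sum() and reduced[4] == slots[4:].sum()
# With reduction="none" this loop is skipped, so after decryption the caller
# finishes the job, e.g. tf.reduce_sum(decrypted, axis=0) as in the updated tests.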