diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index 0f1a052..66230c0 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -72,6 +72,7 @@ REGISTER_OP("PolynomialExport64") .Input("shell_context: variant") .Input("val: variant") .Input("runtime_batching_dim: int64") + .Attr("final_scaling_factor: int") .Output("out: dtype") .SetShapeFn(ExportAndAddBatchingDimShape<1>); @@ -94,6 +95,7 @@ REGISTER_OP("Decrypt64") .Input("key: variant") .Input("val: variant") .Input("runtime_batching_dim: int64") + .Attr("final_scaling_factor: int") .Output("out: dtype") .SetShapeFn(ExportAndAddBatchingDimShape<2>); @@ -298,6 +300,7 @@ REGISTER_OP("DecryptFastRotated64") .Input("fast_rotation_key: variant") .Input("val: variant") .Input("runtime_batching_dim: int64") + .Attr("final_scaling_factor: int") .Output("out: dtype") .SetShapeFn(ExportAndAddBatchingDimShape<2>); diff --git a/tf_shell/cc/optimizers/moduli_autotune.cc b/tf_shell/cc/optimizers/moduli_autotune.cc index 61c8e4b..ed5a447 100644 --- a/tf_shell/cc/optimizers/moduli_autotune.cc +++ b/tf_shell/cc/optimizers/moduli_autotune.cc @@ -26,8 +26,7 @@ namespace { constexpr bool const qs_mod_t_is_one = true; -constexpr bool const debug_moduli = true; -constexpr bool const debug_mul_depth = false; +constexpr bool const debug_moduli = false; constexpr bool const debug_noise_estimation = false; constexpr bool const debug_graph = false; constexpr bool const debug_output_params = true; @@ -109,8 +108,7 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation, } else { []() { static_assert(flag, "AddScalarConstNode does not support this type"); - } - (); + }(); } tensor->set_allocated_tensor_shape(tensor_shape.release()); (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release()); @@ -172,7 +170,8 @@ Status GetAutoShellContextParams(utils::MutableNodeView* autocontext, StatusOr DecryptUsesSameContext(utils::MutableNodeView const* node_view, utils::MutableNodeView const* context) { utils::MutableNodeView const* trace = node_view; - if (trace == nullptr || !IsDecrypt(*trace->node())) { + if (trace == nullptr || + !(IsDecrypt(*trace->node()) || IsDecode(*trace->node()))) { return errors::InvalidArgument( "Expected the node to be a decrypt node, but found ", trace->GetOp()); } @@ -216,102 +215,37 @@ StatusOr DecryptUsesSameContext(utils::MutableNodeView const* node_view, return true; } -StatusOr GetMulDepth(utils::MutableGraphView& graph_view, - utils::MutableNodeView const* autocontext) { - // Traverse the graph and return the maximum multiplicative depth. +StatusOr MaxScalingFactor(utils::MutableGraphView& graph_view, + utils::MutableNodeView const* autocontext) { + // Traverse the graph and return the maximum scaling factor across all decrypt + // ops tied to this autocontext. int const num_nodes = graph_view.NumNodes(); - std::vector node_mul_depth(num_nodes); - - if constexpr (debug_mul_depth) { - std::cout << "Calculating multiplicative depth." << std::endl; - } + int64_t max_sf = 0; - uint64_t max_depth = 0; for (int i = 0; i < num_nodes; ++i) { auto const* this_node_view = graph_view.GetNode(i); auto const* this_node_def = this_node_view->node(); - if (IsArithmetic(*this_node_def) || IsTfShellMatMul(*this_node_def) || - IsMulCtTfScalar(*this_node_def) || IsMulPtTfScalar(*this_node_def)) { - // Get the fanin nodes. - int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index(); - int const fanin_b_index = this_node_view->GetRegularFanin(2).node_index(); - int max_fanin_depth = std::max(node_mul_depth[fanin_a_index], - node_mul_depth[fanin_b_index]); - - // Update the multiplicative depth of this node. - if (IsMulCtCt(*this_node_def) || IsMulCtPt(*this_node_def) || - IsMulPtPt(*this_node_def) || IsTfShellMatMul(*this_node_def)) { - node_mul_depth[i] = max_fanin_depth + 1; - } else { - node_mul_depth[i] = max_fanin_depth; - } - } - - else if (IsNegCt(*this_node_def) || - IsFastReduceSumByRotation(*this_node_def) || - IsUnsortedCtSegmentSum(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } - - else if (IsRoll(*this_node_def) || IsReduceSumByRotation(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(2).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } - - else if (IsConv2d(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index(); - int const fanin_b_index = this_node_view->GetRegularFanin(2).node_index(); - node_mul_depth[i] = std::max(node_mul_depth[fanin_a_index], - node_mul_depth[fanin_b_index]) + - 1; - } - - else if (IsMaxUnpool2d(*this_node_def)) { - // Max unpool performs a multiplication but does not affect the scaling - // factor since it is by a selection 0 or 1 plaintext. - int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } - - else if (IsExpandDimsVariant(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } else if (IsBroadcastToShape(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } else if (IsReshape(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - } else if (IsStridedSlice(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - - // Decryption is where the maximum multiplicative depth is reached. - } else if (IsDecrypt(*this_node_def)) { - int const fanin_a_index = this_node_view->GetRegularFanin(2).node_index(); - node_mul_depth[i] = node_mul_depth[fanin_a_index]; - + // Decryption is where the maximum multiplicative depth is reached. + if (IsDecrypt(*this_node_def) || IsDecode(*this_node_def)) { // Ensure the decrypt op uses the same autocontext node as the argument // (for the case where there are multiple autocontext nodes in the graph). TF_ASSIGN_OR_RETURN(bool is_same_autocontext, DecryptUsesSameContext(this_node_view, autocontext)); - if (is_same_autocontext) { - max_depth = std::max(max_depth, node_mul_depth[i]); + + int64 sf = 0; + if (!TryGetNodeAttr(*this_node_def, "final_scaling_factor", &sf)) { + std::cout + << "WARNING: Could not determine scaling factor in Decrypt op." + << " Plaintext modulus may be underprovisioned." << std::endl; } - } - if constexpr (debug_mul_depth) { - std::cout << " Node: " << this_node_def->name() - << " Depth: " << node_mul_depth[i] - << " Max depth: " << max_depth << std::endl; + if (is_same_autocontext) { + max_sf = std::max(max_sf, sf); + } } } - if constexpr (debug_mul_depth) { - std::cout << "Max Multiplicative Depth: " << max_depth << std::endl; - } - return max_depth; + return max_sf; } // Function for modular exponentiation @@ -862,12 +796,12 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view, } // Create the new inputs for the ShellContext node. - std::string log_n_name = "ContextImport64/log_n"; - std::string qs_name = "ContextImport64/main_moduli"; - std::string ps_name = "ContextImport64/aux_moduli"; - std::string t_name = "ContextImport64/plaintext_modulus"; - std::string noise_var_name = "ContextImport64/noise_variance"; - std::string seed_str_name = "ContextImport64/seed"; + std::string log_n_name = autocontext->GetName() + "/log_n"; + std::string qs_name = autocontext->GetName() + "/main_moduli"; + std::string ps_name = autocontext->GetName() + "/aux_moduli"; + std::string t_name = autocontext->GetName() + "/plaintext_modulus"; + std::string noise_var_name = autocontext->GetName() + "/noise_variance"; + std::string seed_str_name = autocontext->GetName() + "/seed"; utils::Mutation* mutation = graph_view.GetMutationBuilder(); std::string device = autocontext->GetDevice(); @@ -942,18 +876,18 @@ Status OptimizeAutocontext(utils::MutableGraphView& graph_view, ShellAutoParams auto_params; TF_RETURN_IF_ERROR(GetAutoShellContextParams(autocontext, auto_params)); - // Find the maximum multiplicative depth of the graph and use this to set - // the plaintext modulus t, based on the scaling factor and depth. - // The depth computation includes the initial scaling factor included during - // encryption. - TF_ASSIGN_OR_RETURN(int mul_depth, GetMulDepth(graph_view, autocontext)); - uint64_t mul_bits = - BitWidth(std::pow(auto_params.scaling_factor, std::pow(2, mul_depth))); - uint64_t total_plaintext_bits = auto_params.cleartext_bits + mul_bits; + // Find the maximum scaling factor used in the graph and use this to set the + // plaintext modulus t. The maximum plaintext bits * the maximum scaling + // factor is the maximum size. This is too conservative an estimate, so + // for now, it is disabled. + // TF_ASSIGN_OR_RETURN(int64_t max_sf, + // MaxScalingFactor(graph_view, autocontext)); + // uint64_t sf_bits = BitWidth(max_sf); + uint64_t total_plaintext_bits = auto_params.cleartext_bits; if constexpr (debug_moduli) { - std::cout << "Multiplicative Depth: " << mul_depth << std::endl; - std::cout << "Bits of scaling factor: " << mul_bits << std::endl; + // std::cout << "Max bits of scaling factor upon decryption: " << sf_bits + // << std::endl; std::cout << "Total Cleartext Bits: " << total_plaintext_bits << std::endl; } if (total_plaintext_bits >= kMaxPrimeBitsCiphertext) { diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 59a25cf..b223b03 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import tensorflow as tf import tf_shell.python.shell_ops as shell_ops from tf_shell.python.shell_context import ShellContext64 @@ -321,7 +322,7 @@ def __neg__(self): def __mul__(self, other): if isinstance(other, ShellTensor64): - matched_self, matched_other = _match_moduli_and_scaling(self, other) + matched_self, matched_other = _match_moduli(self, other) if self.is_encrypted and other.is_encrypted: if self._is_fast_rotated or other._is_fast_rotated: @@ -360,7 +361,7 @@ def __mul__(self, other): _level=matched_self._level, _num_mod_reductions=matched_self._num_mod_reductions, _underlying_dtype=self._underlying_dtype, - _scaling_factor=matched_self._scaling_factor**2, + _scaling_factor=matched_self._scaling_factor * matched_other._scaling_factor, _is_enc=self._is_enc or other._is_enc, _is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated, ) @@ -471,33 +472,29 @@ def mod_reduce_tensor64(shell_tensor): return reduced_self - -def _match_moduli_and_scaling(x, y): - with tf.name_scope("match_moduli_and_scaling"): +def _match_moduli(x, y): + with tf.name_scope("match_moduli"): # Mod switch to the smaller modulus of the two. while x._num_mod_reductions < y._num_mod_reductions: x = mod_reduce_tensor64(x) while x._num_mod_reductions > y._num_mod_reductions: y = mod_reduce_tensor64(y) - # Make sure the scaling factors are compatible. This should always be - # true unless the user is mixing contexts with different scaling - # factors. - xsf = x._scaling_factor - ysf = y._scaling_factor - while xsf < ysf: - xsf *= x._context.scaling_factor - while ysf < xsf: - ysf *= y._context.scaling_factor - if xsf != ysf: - raise ValueError(f"Scaling factors must be compatible. Got {xsf} and {ysf}") + return x, y + +def _match_moduli_and_scaling(x, y): + with tf.name_scope("match_moduli_and_scaling"): + x, y = _match_moduli(x, y) - # Match the scaling factors. - if xsf > x._scaling_factor: - x = x.__mul__(xsf / (x._scaling_factor**2)) - if ysf > y._scaling_factor: - y = y.__mul__(ysf / (y._scaling_factor**2)) + gcd = math.gcd(x._scaling_factor, y._scaling_factor) + lcm = math.lcm(x._scaling_factor, y._scaling_factor) + # Match the scaling factors. + if lcm > x._scaling_factor: + x = x.__mul__(gcd / x._scaling_factor) + if lcm > y._scaling_factor: + y = y.__mul__(gcd / y._scaling_factor) + return x, y @@ -663,6 +660,7 @@ def to_tensorflow(s_tensor, key=None): runtime_batching_dim=s_tensor._context.num_slots, dtype=shell_dtype, batching_dim=batching_dim, + final_scaling_factor=s_tensor._scaling_factor, ) elif s_tensor.is_encrypted: @@ -679,6 +677,7 @@ def to_tensorflow(s_tensor, key=None): runtime_batching_dim=s_tensor._context.num_slots, dtype=shell_dtype, batching_dim=batching_dim, + final_scaling_factor=s_tensor._scaling_factor, ) elif not s_tensor.is_encrypted: @@ -690,6 +689,7 @@ def to_tensorflow(s_tensor, key=None): runtime_batching_dim=s_tensor._context.num_slots, dtype=shell_dtype, batching_dim=batching_dim, + final_scaling_factor=s_tensor._scaling_factor, ) else: @@ -894,7 +894,7 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"): ) # Encode the plaintext x to the same scaling factor as y. - scaled_x = _encode_scaling(x, y._scaling_factor) + scaled_x = _encode_scaling(x, y._context.scaling_factor) if pt_ct_reduction == "galois": if not isinstance(rotation_key, ShellRotationKey64): @@ -926,7 +926,7 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"): _level=y._level, _num_mod_reductions=y._num_mod_reductions, _underlying_dtype=y._underlying_dtype, - _scaling_factor=y._scaling_factor**2, + _scaling_factor=y._scaling_factor * y._context.scaling_factor, _is_enc=True, _is_fast_rotated=pt_ct_reduction == "fast", ) @@ -1112,7 +1112,7 @@ def _conv2d(x, filt, strides, padding, dilations, func): "A ShellTensor which has been fast-rotated or fast-reduced-summed cannot be an input to conv2d." ) - matched_x, matched_filt = _match_moduli_and_scaling(x, filt) + matched_x, matched_filt = _match_moduli(x, filt) return ShellTensor64( _raw_tensor=func( @@ -1128,7 +1128,7 @@ def _conv2d(x, filt, strides, padding, dilations, func): _level=matched_x._level, _num_mod_reductions=matched_x._num_mod_reductions, _underlying_dtype=matched_x._underlying_dtype, - _scaling_factor=matched_x._scaling_factor**2, + _scaling_factor=matched_x._scaling_factor * matched_filt._scaling_factor, _is_enc=True, _is_fast_rotated=False, ) diff --git a/tf_shell/test/composite_test.py b/tf_shell/test/composite_test.py index 90ff362..ebeeb27 100644 --- a/tf_shell/test/composite_test.py +++ b/tf_shell/test/composite_test.py @@ -95,9 +95,12 @@ def _test_ct_ct_mulmul(self, test_context): ) # Here, ec has a mul depth of 1 while eb has a mul depth of 0. To - # multiply them, eb needs to be mod reduced to match ec. ShellTensor - # should handle this automatically. + # multiply them, tf-shell needs to account for the difference in scaling + # factors. For ct_ct multiplication, the scaling factors must match, so + # in this case eb will be scaled up with a ct_pt multiplication to match + # ec. tf-shell will handle this automatically. ed = ec * eb + self.assertEqual(ed._scaling_factor, (ea._scaling_factor**2)**2) self.assertAllClose( a * b * b, tf_shell.to_tensorflow(ed, test_context.key), atol=1e-3 ) @@ -132,10 +135,13 @@ def _test_ct_pt_mulmul(self, test_context): ) # Here, ec has a mul depth of 1 while b is has a mul depth of 0. To - # multiply them, b needs to be encoded as a shell plaintext with - # moduli which match the now-mod-reduced ec. ShellTensor should handle - # this automatically. + # multiply them, tf-shell needs to account for the difference in + # scaling factors. For ct_pt multiplication, the scaling factors do + # not need to match, but their product must be remembered and divided + # out when the result is decrypted. tf-shell will handle this + # automatically. ed = ec * b + self.assertEqual(ed._scaling_factor, ea._scaling_factor**3) self.assertAllClose( a * b * b, tf_shell.to_tensorflow(ed, test_context.key), atol=1e-3 ) diff --git a/tf_shell/test/test_utils.py b/tf_shell/test/test_utils.py index b8be8ba..dd74b09 100644 --- a/tf_shell/test/test_utils.py +++ b/tf_shell/test/test_utils.py @@ -189,7 +189,8 @@ def uniform_for_n_muls(test_context, num_muls, shape=None, subsequent_adds=0): min_val, max_val = get_bounds_for_n_muls(test_context, num_muls) - subsequent_adds = tf.cast(subsequent_adds, min_val.dtype) + if hasattr(min_val, "dtype"): + subsequent_adds = tf.cast(subsequent_adds, min_val.dtype) min_val = min_val / (subsequent_adds + 1) max_val = max_val / (subsequent_adds + 1) diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py index df7e03a..8164556 100644 --- a/tf_shell_ml/dpsgd_sequential_model.py +++ b/tf_shell_ml/dpsgd_sequential_model.py @@ -111,6 +111,7 @@ def train_step(self, data): with tf.device(self.labels_party_dev): if self.disable_encryption: enc_y = y + public_backprop_rotation_key = None else: backprop_context = self.backprop_context_fn() backprop_secret_key = tf_shell.create_key64(