Skip to content

Commit

Permalink
More efficient use of plaintext moduli.
Browse files Browse the repository at this point in the history
Ciphertext plaintext multiplication does not require squaring the scaling factor.
  • Loading branch information
james-choncholas committed Oct 26, 2024
1 parent e9d7a83 commit 0f7a688
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 134 deletions.
3 changes: 3 additions & 0 deletions tf_shell/cc/ops/shell_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ REGISTER_OP("PolynomialExport64")
.Input("shell_context: variant")
.Input("val: variant")
.Input("runtime_batching_dim: int64")
.Attr("final_scaling_factor: int")
.Output("out: dtype")
.SetShapeFn(ExportAndAddBatchingDimShape<1>);

Expand All @@ -94,6 +95,7 @@ REGISTER_OP("Decrypt64")
.Input("key: variant")
.Input("val: variant")
.Input("runtime_batching_dim: int64")
.Attr("final_scaling_factor: int")
.Output("out: dtype")
.SetShapeFn(ExportAndAddBatchingDimShape<2>);

Expand Down Expand Up @@ -298,6 +300,7 @@ REGISTER_OP("DecryptFastRotated64")
.Input("fast_rotation_key: variant")
.Input("val: variant")
.Input("runtime_batching_dim: int64")
.Attr("final_scaling_factor: int")
.Output("out: dtype")
.SetShapeFn(ExportAndAddBatchingDimShape<2>);

Expand Down
140 changes: 37 additions & 103 deletions tf_shell/cc/optimizers/moduli_autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ namespace {

constexpr bool const qs_mod_t_is_one = true;

constexpr bool const debug_moduli = true;
constexpr bool const debug_mul_depth = false;
constexpr bool const debug_moduli = false;
constexpr bool const debug_noise_estimation = false;
constexpr bool const debug_graph = false;
constexpr bool const debug_output_params = true;
Expand Down Expand Up @@ -109,8 +108,7 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation,
} else {
[]<bool flag = false>() {
static_assert(flag, "AddScalarConstNode does not support this type");
}
();
}();
}
tensor->set_allocated_tensor_shape(tensor_shape.release());
(*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
Expand Down Expand Up @@ -172,7 +170,8 @@ Status GetAutoShellContextParams(utils::MutableNodeView* autocontext,
StatusOr<bool> DecryptUsesSameContext(utils::MutableNodeView const* node_view,
utils::MutableNodeView const* context) {
utils::MutableNodeView const* trace = node_view;
if (trace == nullptr || !IsDecrypt(*trace->node())) {
if (trace == nullptr ||
!(IsDecrypt(*trace->node()) || IsDecode(*trace->node()))) {
return errors::InvalidArgument(
"Expected the node to be a decrypt node, but found ", trace->GetOp());
}
Expand Down Expand Up @@ -216,102 +215,37 @@ StatusOr<bool> DecryptUsesSameContext(utils::MutableNodeView const* node_view,
return true;
}

StatusOr<int> GetMulDepth(utils::MutableGraphView& graph_view,
utils::MutableNodeView const* autocontext) {
// Traverse the graph and return the maximum multiplicative depth.
StatusOr<int64_t> MaxScalingFactor(utils::MutableGraphView& graph_view,
utils::MutableNodeView const* autocontext) {
// Traverse the graph and return the maximum scaling factor across all decrypt
// ops tied to this autocontext.
int const num_nodes = graph_view.NumNodes();
std::vector<uint64_t> node_mul_depth(num_nodes);

if constexpr (debug_mul_depth) {
std::cout << "Calculating multiplicative depth." << std::endl;
}
int64_t max_sf = 0;

uint64_t max_depth = 0;
for (int i = 0; i < num_nodes; ++i) {
auto const* this_node_view = graph_view.GetNode(i);
auto const* this_node_def = this_node_view->node();

if (IsArithmetic(*this_node_def) || IsTfShellMatMul(*this_node_def) ||
IsMulCtTfScalar(*this_node_def) || IsMulPtTfScalar(*this_node_def)) {
// Get the fanin nodes.
int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index();
int const fanin_b_index = this_node_view->GetRegularFanin(2).node_index();
int max_fanin_depth = std::max(node_mul_depth[fanin_a_index],
node_mul_depth[fanin_b_index]);

// Update the multiplicative depth of this node.
if (IsMulCtCt(*this_node_def) || IsMulCtPt(*this_node_def) ||
IsMulPtPt(*this_node_def) || IsTfShellMatMul(*this_node_def)) {
node_mul_depth[i] = max_fanin_depth + 1;
} else {
node_mul_depth[i] = max_fanin_depth;
}
}

else if (IsNegCt(*this_node_def) ||
IsFastReduceSumByRotation(*this_node_def) ||
IsUnsortedCtSegmentSum(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
}

else if (IsRoll(*this_node_def) || IsReduceSumByRotation(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(2).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
}

else if (IsConv2d(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index();
int const fanin_b_index = this_node_view->GetRegularFanin(2).node_index();
node_mul_depth[i] = std::max(node_mul_depth[fanin_a_index],
node_mul_depth[fanin_b_index]) +
1;
}

else if (IsMaxUnpool2d(*this_node_def)) {
// Max unpool performs a multiplication but does not affect the scaling
// factor since it is by a selection 0 or 1 plaintext.
int const fanin_a_index = this_node_view->GetRegularFanin(1).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
}

else if (IsExpandDimsVariant(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
} else if (IsBroadcastToShape(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
} else if (IsReshape(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
} else if (IsStridedSlice(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];

// Decryption is where the maximum multiplicative depth is reached.
} else if (IsDecrypt(*this_node_def)) {
int const fanin_a_index = this_node_view->GetRegularFanin(2).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];

// Decryption is where the maximum multiplicative depth is reached.
if (IsDecrypt(*this_node_def) || IsDecode(*this_node_def)) {
// Ensure the decrypt op uses the same autocontext node as the argument
// (for the case where there are multiple autocontext nodes in the graph).
TF_ASSIGN_OR_RETURN(bool is_same_autocontext,
DecryptUsesSameContext(this_node_view, autocontext));
if (is_same_autocontext) {
max_depth = std::max(max_depth, node_mul_depth[i]);

int64 sf = 0;
if (!TryGetNodeAttr(*this_node_def, "final_scaling_factor", &sf)) {
std::cout
<< "WARNING: Could not determine scaling factor in Decrypt op."
<< " Plaintext modulus may be underprovisioned." << std::endl;
}
}

if constexpr (debug_mul_depth) {
std::cout << " Node: " << this_node_def->name()
<< " Depth: " << node_mul_depth[i]
<< " Max depth: " << max_depth << std::endl;
if (is_same_autocontext) {
max_sf = std::max(max_sf, sf);
}
}
}
if constexpr (debug_mul_depth) {
std::cout << "Max Multiplicative Depth: " << max_depth << std::endl;
}
return max_depth;
return max_sf;
}

// Function for modular exponentiation
Expand Down Expand Up @@ -862,12 +796,12 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
}

// Create the new inputs for the ShellContext node.
std::string log_n_name = "ContextImport64/log_n";
std::string qs_name = "ContextImport64/main_moduli";
std::string ps_name = "ContextImport64/aux_moduli";
std::string t_name = "ContextImport64/plaintext_modulus";
std::string noise_var_name = "ContextImport64/noise_variance";
std::string seed_str_name = "ContextImport64/seed";
std::string log_n_name = autocontext->GetName() + "/log_n";
std::string qs_name = autocontext->GetName() + "/main_moduli";
std::string ps_name = autocontext->GetName() + "/aux_moduli";
std::string t_name = autocontext->GetName() + "/plaintext_modulus";
std::string noise_var_name = autocontext->GetName() + "/noise_variance";
std::string seed_str_name = autocontext->GetName() + "/seed";

utils::Mutation* mutation = graph_view.GetMutationBuilder();
std::string device = autocontext->GetDevice();
Expand Down Expand Up @@ -942,18 +876,18 @@ Status OptimizeAutocontext(utils::MutableGraphView& graph_view,
ShellAutoParams auto_params;
TF_RETURN_IF_ERROR(GetAutoShellContextParams(autocontext, auto_params));

// Find the maximum multiplicative depth of the graph and use this to set
// the plaintext modulus t, based on the scaling factor and depth.
// The depth computation includes the initial scaling factor included during
// encryption.
TF_ASSIGN_OR_RETURN(int mul_depth, GetMulDepth(graph_view, autocontext));
uint64_t mul_bits =
BitWidth(std::pow(auto_params.scaling_factor, std::pow(2, mul_depth)));
uint64_t total_plaintext_bits = auto_params.cleartext_bits + mul_bits;
// Find the maximum scaling factor used in the graph and use this to set the
// plaintext modulus t. The maximum plaintext bits * the maximum scaling
// factor is the maximum size. This is too conservative an estimate, so
// for now, it is disabled.
// TF_ASSIGN_OR_RETURN(int64_t max_sf,
// MaxScalingFactor(graph_view, autocontext));
// uint64_t sf_bits = BitWidth(max_sf);
uint64_t total_plaintext_bits = auto_params.cleartext_bits;

if constexpr (debug_moduli) {
std::cout << "Multiplicative Depth: " << mul_depth << std::endl;
std::cout << "Bits of scaling factor: " << mul_bits << std::endl;
// std::cout << "Max bits of scaling factor upon decryption: " << sf_bits
// << std::endl;
std::cout << "Total Cleartext Bits: " << total_plaintext_bits << std::endl;
}
if (total_plaintext_bits >= kMaxPrimeBitsCiphertext) {
Expand Down
50 changes: 25 additions & 25 deletions tf_shell/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import tensorflow as tf
import tf_shell.python.shell_ops as shell_ops
from tf_shell.python.shell_context import ShellContext64
Expand Down Expand Up @@ -321,7 +322,7 @@ def __neg__(self):

def __mul__(self, other):
if isinstance(other, ShellTensor64):
matched_self, matched_other = _match_moduli_and_scaling(self, other)
matched_self, matched_other = _match_moduli(self, other)

if self.is_encrypted and other.is_encrypted:
if self._is_fast_rotated or other._is_fast_rotated:
Expand Down Expand Up @@ -360,7 +361,7 @@ def __mul__(self, other):
_level=matched_self._level,
_num_mod_reductions=matched_self._num_mod_reductions,
_underlying_dtype=self._underlying_dtype,
_scaling_factor=matched_self._scaling_factor**2,
_scaling_factor=matched_self._scaling_factor * matched_other._scaling_factor,
_is_enc=self._is_enc or other._is_enc,
_is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated,
)
Expand Down Expand Up @@ -471,33 +472,29 @@ def mod_reduce_tensor64(shell_tensor):

return reduced_self


def _match_moduli_and_scaling(x, y):
with tf.name_scope("match_moduli_and_scaling"):
def _match_moduli(x, y):
with tf.name_scope("match_moduli"):
# Mod switch to the smaller modulus of the two.
while x._num_mod_reductions < y._num_mod_reductions:
x = mod_reduce_tensor64(x)
while x._num_mod_reductions > y._num_mod_reductions:
y = mod_reduce_tensor64(y)

# Make sure the scaling factors are compatible. This should always be
# true unless the user is mixing contexts with different scaling
# factors.
xsf = x._scaling_factor
ysf = y._scaling_factor
while xsf < ysf:
xsf *= x._context.scaling_factor
while ysf < xsf:
ysf *= y._context.scaling_factor
if xsf != ysf:
raise ValueError(f"Scaling factors must be compatible. Got {xsf} and {ysf}")
return x, y

def _match_moduli_and_scaling(x, y):
with tf.name_scope("match_moduli_and_scaling"):
x, y = _match_moduli(x, y)

# Match the scaling factors.
if xsf > x._scaling_factor:
x = x.__mul__(xsf / (x._scaling_factor**2))
if ysf > y._scaling_factor:
y = y.__mul__(ysf / (y._scaling_factor**2))
gcd = math.gcd(x._scaling_factor, y._scaling_factor)
lcm = math.lcm(x._scaling_factor, y._scaling_factor)

# Match the scaling factors.
if lcm > x._scaling_factor:
x = x.__mul__(gcd / x._scaling_factor)
if lcm > y._scaling_factor:
y = y.__mul__(gcd / y._scaling_factor)

return x, y


Expand Down Expand Up @@ -663,6 +660,7 @@ def to_tensorflow(s_tensor, key=None):
runtime_batching_dim=s_tensor._context.num_slots,
dtype=shell_dtype,
batching_dim=batching_dim,
final_scaling_factor=s_tensor._scaling_factor,
)

elif s_tensor.is_encrypted:
Expand All @@ -679,6 +677,7 @@ def to_tensorflow(s_tensor, key=None):
runtime_batching_dim=s_tensor._context.num_slots,
dtype=shell_dtype,
batching_dim=batching_dim,
final_scaling_factor=s_tensor._scaling_factor,
)

elif not s_tensor.is_encrypted:
Expand All @@ -690,6 +689,7 @@ def to_tensorflow(s_tensor, key=None):
runtime_batching_dim=s_tensor._context.num_slots,
dtype=shell_dtype,
batching_dim=batching_dim,
final_scaling_factor=s_tensor._scaling_factor,
)

else:
Expand Down Expand Up @@ -894,7 +894,7 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"):
)

# Encode the plaintext x to the same scaling factor as y.
scaled_x = _encode_scaling(x, y._scaling_factor)
scaled_x = _encode_scaling(x, y._context.scaling_factor)

if pt_ct_reduction == "galois":
if not isinstance(rotation_key, ShellRotationKey64):
Expand Down Expand Up @@ -926,7 +926,7 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"):
_level=y._level,
_num_mod_reductions=y._num_mod_reductions,
_underlying_dtype=y._underlying_dtype,
_scaling_factor=y._scaling_factor**2,
_scaling_factor=y._scaling_factor * y._context.scaling_factor,
_is_enc=True,
_is_fast_rotated=pt_ct_reduction == "fast",
)
Expand Down Expand Up @@ -1112,7 +1112,7 @@ def _conv2d(x, filt, strides, padding, dilations, func):
"A ShellTensor which has been fast-rotated or fast-reduced-summed cannot be an input to conv2d."
)

matched_x, matched_filt = _match_moduli_and_scaling(x, filt)
matched_x, matched_filt = _match_moduli(x, filt)

return ShellTensor64(
_raw_tensor=func(
Expand All @@ -1128,7 +1128,7 @@ def _conv2d(x, filt, strides, padding, dilations, func):
_level=matched_x._level,
_num_mod_reductions=matched_x._num_mod_reductions,
_underlying_dtype=matched_x._underlying_dtype,
_scaling_factor=matched_x._scaling_factor**2,
_scaling_factor=matched_x._scaling_factor * matched_filt._scaling_factor,
_is_enc=True,
_is_fast_rotated=False,
)
Expand Down
16 changes: 11 additions & 5 deletions tf_shell/test/composite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,12 @@ def _test_ct_ct_mulmul(self, test_context):
)

# Here, ec has a mul depth of 1 while eb has a mul depth of 0. To
# multiply them, eb needs to be mod reduced to match ec. ShellTensor
# should handle this automatically.
# multiply them, tf-shell needs to account for the difference in scaling
# factors. For ct_ct multiplication, the scaling factors must match, so
# in this case eb will be scaled up with a ct_pt multiplication to match
# ec. tf-shell will handle this automatically.
ed = ec * eb
self.assertEqual(ed._scaling_factor, (ea._scaling_factor**2)**2)
self.assertAllClose(
a * b * b, tf_shell.to_tensorflow(ed, test_context.key), atol=1e-3
)
Expand Down Expand Up @@ -132,10 +135,13 @@ def _test_ct_pt_mulmul(self, test_context):
)

# Here, ec has a mul depth of 1 while b has a mul depth of 0. To
# multiply them, b needs to be encoded as a shell plaintext with
# moduli which match the now-mod-reduced ec. ShellTensor should handle
# this automatically.
# multiply them, tf-shell needs to account for the difference in
# scaling factors. For ct_pt multiplication, the scaling factors do
# not need to match, but their product must be remembered and divided
# out when the result is decrypted. tf-shell will handle this
# automatically.
ed = ec * b
self.assertEqual(ed._scaling_factor, ea._scaling_factor**3)
self.assertAllClose(
a * b * b, tf_shell.to_tensorflow(ed, test_context.key), atol=1e-3
)
Expand Down
3 changes: 2 additions & 1 deletion tf_shell/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def uniform_for_n_muls(test_context, num_muls, shape=None, subsequent_adds=0):

min_val, max_val = get_bounds_for_n_muls(test_context, num_muls)

subsequent_adds = tf.cast(subsequent_adds, min_val.dtype)
if hasattr(min_val, "dtype"):
subsequent_adds = tf.cast(subsequent_adds, min_val.dtype)
min_val = min_val / (subsequent_adds + 1)
max_val = max_val / (subsequent_adds + 1)

Expand Down
Loading

0 comments on commit 0f7a688

Please sign in to comment.