diff --git a/shell_tensor/python/shell_tensor.py b/shell_tensor/python/shell_tensor.py index 7cdd999..0dbb3ee 100644 --- a/shell_tensor/python/shell_tensor.py +++ b/shell_tensor/python/shell_tensor.py @@ -18,6 +18,54 @@ import shell_tensor.shell as raw_bindings +class ShellContext64(object): + def __init__( + self, + shell_context, + log_n, + main_moduli, + aux_moduli, + plaintext_modulus, + noise_variance, + seed, + ): + self._raw_context = shell_context + self.log_n = log_n + self.num_slots = 2**log_n + self.main_moduli = main_moduli + self.aux_moduli = aux_moduli + self.plaintext_modulus = plaintext_modulus + self.noise_variance = noise_variance + if noise_variance % 2 == 0: + self.noise_bits = noise_variance.bit_length() + else: + self.noise_bits = noise_variance.bit_length() + 1 + self.seed = seed + + +def create_context64( + log_n, main_moduli, aux_moduli, plaintext_modulus, noise_variance, seed="" +): + shell_context = shell_ops.context_import64( + log_n=log_n, + main_moduli=main_moduli, + aux_moduli=aux_moduli, + plaintext_modulus=plaintext_modulus, + noise_variance=noise_variance, + seed=seed, + ) + + return ShellContext64( + shell_context=shell_context, + log_n=log_n, + main_moduli=main_moduli, + aux_moduli=aux_moduli, + plaintext_modulus=plaintext_modulus, + noise_variance=noise_variance, + seed=seed, + ) + + class ShellTensor64(object): is_tensor_like = True # needed to pass tf.is_tensor, new as of TF 2.2+ @@ -25,11 +73,11 @@ def __init__( self, value, context, - num_slots, underlying_dtype, - is_enc=False, - fxp_fractional_bits=0, - mult_count=0, + is_enc, + fxp_fractional_bits, + mul_count, + noise_bit_count, ): assert isinstance( value, tf.Tensor @@ -37,9 +85,9 @@ def __init__( assert ( value.dtype is tf.variant ), f"Should be variant tensor, instead got {value.dtype}" + assert isinstance(context, ShellContext64), f"Should be ShellContext64" self._raw = value self._context = context - self._num_slots = num_slots self._underlying_dtype = underlying_dtype self._is_enc = is_enc @@ -54,24 +102,26 @@ def __init__( # ShellTensor keeps track of the number of multiplications that have # occured, then right shift for all multiplications together at the end. # This requires keeping track of the number of multiplications that have - # occured. This is the mult_count parameter. + # occured. This is the mul_count parameter. # # Note that now adding/subtracting/multiplying two ShellTensors together # is more complicated as each could have a different number of # multiplications, and thus a different number of fractional bits. The - # operand with fewer mult_count must be scaled up by - # 2**(difference_in_mult_count * fractional_bits) to match the other + # operand with fewer mul_count must be scaled up by + # 2**(difference_in_mul_count * fractional_bits) to match the other # operand. _self_at_fxp_multiplier caches scaled up versions of itself - # for each previously requested mult_count. + # for each previously requested mul_count. self._fxp_fractional_bits = fxp_fractional_bits - self._mult_count = ( - mult_count # number of preceeding multiplications resulting in self. + self._mul_count = ( + mul_count # number of preceding multiplications resulting in self. 
) - self._self_at_fxp_multiplier = {mult_count: self} + self._self_at_fxp_multiplier = {mul_count: self} + + self._noise_bit_count = noise_bit_count @property def shape(self): - return self._num_slots + self._raw.shape + return [self._context.num_slots] + self._raw.shape @property def name(self): @@ -89,21 +139,42 @@ def plaintext_dtype(self): def is_encrypted(self): return self._is_enc + @property + def noise_bits(self): + return self._noise_bit_count + 1 + + def __getitem__(self, slice): + slots = slice[0] + if slots.start != None or slots.stop != None or slots.step != None: + raise ValueError( + f"ShellTensor does not support intra-slot slicing. Be sure to use `:` on the first dimension. Got {slice}" + ) + return ShellTensor64( + value=self._raw[slice[1:]], + context=self._context, + underlying_dtype=self._underlying_dtype, + is_enc=self.is_encrypted, + fxp_fractional_bits=self._fxp_fractional_bits, + mul_count=self._mul_count, + noise_bit_count=self._noise_bit_count, + ) + + @property def num_fxp_fractional_bits(self): - return self._fxp_fractional_bits * (2**self._mult_count) + return self._fxp_fractional_bits * (2**self._mul_count) def get_encrypted(self, key): if self._is_enc: return self else: return ShellTensor64( - value=shell_ops.encrypt64(self._context, key, self._raw), + value=shell_ops.encrypt64(self._context._raw_context, key, self._raw), context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=True, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=self._noise_bit_count, ) def get_decrypted(self, key=None): @@ -119,7 +190,7 @@ def get_decrypted(self, key=None): # Decrypt op returns a tf Tensor. tf_tensor = shell_ops.decrypt64( - self._context, + self._context._raw_context, key, self._raw, dtype=fxp_dtype, @@ -128,15 +199,15 @@ def get_decrypted(self, key=None): # Convert out of fixed point to the underlying dtype. 
return _from_fixed_point( tf_tensor, - self.num_fxp_fractional_bits(), + self.num_fxp_fractional_bits, self._underlying_dtype, ) def __add__(self, other): if isinstance(other, ShellTensor64): - max_mult_count = max(self._mult_count, other._mult_count) - matched_other = other.get_at_multiplication_count(max_mult_count) - matched_self = self.get_at_multiplication_count(max_mult_count) + max_mul_count = max(self._mul_count, other._mul_count) + matched_other = other.get_at_multiplication_count(max_mul_count) + matched_self = self.get_at_multiplication_count(max_mul_count) if self.is_encrypted and other.is_encrypted: result_raw = shell_ops.add_ct_ct64( @@ -152,7 +223,7 @@ def __add__(self, other): ) elif not self.is_encrypted and not other.is_encrypted: result_raw = shell_ops.add_pt_pt64( - self._context, matched_self._raw, matched_other._raw + self._context._raw_context, matched_self._raw, matched_other._raw ) else: raise ValueError("Invalid operands") @@ -160,11 +231,11 @@ def __add__(self, other): return ShellTensor64( value=result_raw, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc or other._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=max_mult_count, + mul_count=max_mul_count, + noise_bit_count=self._noise_bit_count + 1, ) elif isinstance(other, tf.Tensor): @@ -186,9 +257,9 @@ def __radd__(self, other): def __sub__(self, other): if isinstance(other, ShellTensor64): - max_mult_count = max(self._mult_count, other._mult_count) - matched_other = other.get_at_multiplication_count(max_mult_count) - matched_self = self.get_at_multiplication_count(max_mult_count) + max_mul_count = max(self._mul_count, other._mul_count) + matched_other = other.get_at_multiplication_count(max_mul_count) + matched_self = self.get_at_multiplication_count(max_mul_count) if self.is_encrypted and other.is_encrypted: result_raw = shell_ops.sub_ct_ct64( @@ -205,7 +276,7 @@ def __sub__(self, other): ) elif not self.is_encrypted and not other.is_encrypted: result_raw = shell_ops.sub_pt_pt64( - self._context, matched_self._raw, matched_other._raw + self._context._raw_context, matched_self._raw, matched_other._raw ) else: raise ValueError("Invalid operands") @@ -213,11 +284,11 @@ def __sub__(self, other): return ShellTensor64( value=result_raw, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc or other._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=max_mult_count, + mul_count=max_mul_count, + noise_bit_count=self._noise_bit_count + 1, ) elif isinstance(other, tf.Tensor): # TODO(jchoncholas): Subtracting a scalar uses a special op that is @@ -238,7 +309,7 @@ def __rsub__(self, other): self._context, other, self._fxp_fractional_bits ) matched_shell_other = shell_other.get_at_multiplication_count( - self._mult_count + self._mul_count ) if self.is_encrypted: @@ -248,17 +319,17 @@ def __rsub__(self, other): ) else: raw_result = shell_ops.sub_pt_pt64( - self._context, matched_shell_other._raw, self._raw + self._context._raw_context, matched_shell_other._raw, self._raw ) return ShellTensor64( value=raw_result, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=self._noise_bit_count + 1, ) else: return NotImplemented @@ -267,23 +338,23 @@ def __neg__(self): if self.is_encrypted: 
raw_result = shell_ops.neg_ct64(self._raw) else: - raw_result = shell_ops.neg_pt64(self._context, self._raw) + raw_result = shell_ops.neg_pt64(self._context._raw_context, self._raw) return ShellTensor64( value=raw_result, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=self._noise_bit_count + 1, ) def __mul__(self, other): if isinstance(other, ShellTensor64): - max_mult_count = max(self._mult_count, other._mult_count) - matched_other = other.get_at_multiplication_count(max_mult_count) - matched_self = self.get_at_multiplication_count(max_mult_count) + max_mul_count = max(self._mul_count, other._mul_count) + matched_other = other.get_at_multiplication_count(max_mul_count) + matched_self = self.get_at_multiplication_count(max_mul_count) if self.is_encrypted and other.is_encrypted: raw_result = shell_ops.mul_ct_ct64( @@ -299,7 +370,7 @@ def __mul__(self, other): ) elif not self.is_encrypted and not other.is_encrypted: raw_result = shell_ops.mul_pt_pt64( - self._context, matched_self._raw, matched_other._raw + self._context._raw_context, matched_self._raw, matched_other._raw ) else: raise ValueError("Invalid operands") @@ -307,38 +378,39 @@ def __mul__(self, other): return ShellTensor64( value=raw_result, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc or other._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=max_mult_count + 1, + mul_count=max_mul_count + 1, + noise_bit_count=matched_self._noise_bit_count + + matched_other._noise_bit_count, ) elif isinstance(other, tf.Tensor): # Multiplying by a scalar uses a special op which is more efficient # than the caller creating creating a ShellTensor the same # dimensions as self and multiplying. if other.shape == []: - # Convert other to fixed point. Using num_fxp_fractional_bits() + # Convert other to fixed point. Using num_fxp_fractional_bits # ensure's it has the same number of fractional bits as self, # taking multiplicative depth into account. - fxp_tensor = _to_fixed_point(other, self.num_fxp_fractional_bits()) + fxp_tensor = _to_fixed_point(other, self.num_fxp_fractional_bits) if self.is_encrypted: raw_result = shell_ops.mul_ct_tf_scalar64( - self._context, self._raw, fxp_tensor + self._context._raw_context, self._raw, fxp_tensor ) else: raw_result = shell_ops.mul_pt_tf_scalar64( - self._context, self._raw, fxp_tensor + self._context._raw_context, self._raw, fxp_tensor ) return ShellTensor64( value=raw_result, context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=self._is_enc, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count + 1, + mul_count=self._mul_count + 1, + noise_bit_count=self._noise_bit_count + self._context.noise_bits, ) # Lift tensorflow tensor to shell tensor before multiplication. 
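For orientation, a minimal end-to-end sketch of the API as reshaped by this patch: create_context64 now returns a ShellContext64 wrapper (carrying num_slots, plaintext_modulus, and noise_bits), and the raw op handle is reached through _raw_context. The moduli and plaintext modulus below are borrowed from the test contexts in this patch; the noise variance, shapes, and fractional bits are illustrative assumptions, not recommended parameters.

import tensorflow as tf
import shell_tensor

context = shell_tensor.create_context64(
    log_n=11,
    main_moduli=[281474976768001, 281474976829441],
    aux_moduli=[],
    plaintext_modulus=134246401,
    noise_variance=8,
)
key = shell_tensor.create_key64(context)

a = tf.random.uniform([context.num_slots, 3], maxval=10, dtype=tf.float32)
b = tf.random.uniform([context.num_slots, 3], maxval=10, dtype=tf.float32)

# Encode into fixed point, then encrypt.
ea = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=4).get_encrypted(key)
eb = shell_tensor.to_shell_tensor(context, b, fxp_fractional_bits=4).get_encrypted(key)

# Each result carries mul_count and an estimated noise_bit_count.
ec = ea * eb + ea
c = ec.get_decrypted(key)  # converts back out of fixed point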
@@ -351,23 +423,23 @@ def __mul__(self, other): def __rmul__(self, other): return self * other - def roll(self, rotation_key, num_slots): + def roll(self, rotation_key, shift): if not self._is_enc: raise ValueError("Unencrypted ShellTensor rotation not supported yet.") else: - num_slots = tf.cast(num_slots, tf.int64) + shift = tf.cast(shift, tf.int64) return ShellTensor64( - value=shell_ops.roll64(rotation_key, self._raw, num_slots), + value=shell_ops.roll64(rotation_key, self._raw, shift), context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=True, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=self._noise_bit_count + 1, # TODO correct? ) - def reduce_sum(self, axis=0, rotation_key=None): + def reduce_sum(self, axis, rotation_key=None): if not self._is_enc: raise ValueError("Unencrypted ShellTensor reduce_sum not supported yet.") # Check axis is a scalar @@ -379,84 +451,101 @@ def reduce_sum(self, axis=0, rotation_key=None): "Rotation key must be provided to reduce_sum over axis 0." ) + # reduce sum does log2(num_slots) rotations and additions. + # TODO: add noise from rotations? + result_noise_bits = ( + self._noise_bit_count + self._context.num_slots.bit_length() + 1, + ) + return ShellTensor64( value=shell_ops.reduce_sum_by_rotation64(self._raw, rotation_key), context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=True, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=result_noise_bits, ) else: if axis >= len(self.shape): raise ValueError("Axis greater than number of dimensions") + + result_noise_bits = ( + self._noise_bit_count + self.shape[axis].bit_length() + 1 + ) + return ShellTensor64( value=shell_ops.reduce_sum64(self._raw, axis), context=self._context, - num_slots=self._num_slots, underlying_dtype=self._underlying_dtype, is_enc=True, fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=self._mult_count, + mul_count=self._mul_count, + noise_bit_count=result_noise_bits, ) - def get_at_multiplication_count(self, mult_count): - """Returns a ShellTensor whose values have been left shifted to match - the specified multiplicative depth. Fixed point multiplication doubles - the number of fractional bits. Since ShellTensor does not right shift - after multiplication, (only after decryption) two ShellTensors may have - different number of fractional bits. This function will return a new - ShellTensor with the same value as itself but left shifted to match the - specified multiplicative depth. + def get_at_multiplication_count(self, mul_count): + """Returns a ShellTensor whose values have been left or right shifted to + match the specified multiplicative depth for a given number of fixed + point fractional bits. Fixed point multiplication doubles the number of + fractional bits. Since ShellTensor does not right shift after every + multiplication, (only required after decryption) two ShellTensors may + have different number of fractional bits and they cannot be for example + added together. This function will return a new ShellTensor with the + same value as itself but shifted to match the specified multiplicative + depth. This function will cache the result for future calls with the same multiplicative depth. - For now, this function only supports increasing the number of - multiplications. 
Decreasing the number of multiplications is possible - and would require right shifting by 2^(-x) mod t, a multiplicative - inverse in the plaintext modulus's field. + Note that right shifting will lose precision which could otherwise be + recovered on decrypting to a floating point data type. """ - if self._mult_count > mult_count: - raise ValueError("Cannot reduce multiplication count of ShellTensor64.") - elif mult_count in self._self_at_fxp_multiplier: - return self._self_at_fxp_multiplier[mult_count] - else: - num_mul_to_do = mult_count - self._mult_count - - wanted_fxp_fractional_bits = self._fxp_fractional_bits * ( - 2**num_mul_to_do - ) - needed_fxp_fractional_bits = ( - wanted_fxp_fractional_bits - self._fxp_fractional_bits - ) + if mul_count < 0: + raise ValueError("mul_count must be non-negative.") + elif mul_count in self._self_at_fxp_multiplier: + return self._self_at_fxp_multiplier[mul_count] + + num_mul_to_do = mul_count - self._mul_count + wanted_frac_bits = self._fxp_fractional_bits * (2**num_mul_to_do) + needed_frac_bits = wanted_frac_bits - self._fxp_fractional_bits + + if needed_frac_bits > 0: + # If we are left shifting, multiply by the number of + # multiplications required. + fxp_multiplier = tf.constant(2**needed_frac_bits, dtype=tf.int64) + elif needed_frac_bits < 0: + # If we are right shifting, need to use the multiplicative + # inverse in the plaintext modulus's field. + pt_modulus = self._context.plaintext_modulus fxp_multiplier = tf.constant( - 2**needed_fxp_fractional_bits, dtype=tf.int64 + pow(2, int(-needed_frac_bits), int(pt_modulus)), dtype=tf.int64 ) + else: + raise IndexError("Asked for self mul_count.") - # Perform the multiplication. - if self.is_encrypted: - raw_result = shell_ops.mul_ct_tf_scalar64( - self._context, self._raw, fxp_multiplier - ) - else: - raw_result = shell_ops.mul_pt_tf_scalar64( - self._context, self._raw, fxp_multiplier - ) - left_shifted = ShellTensor64( - value=raw_result, - context=self._context, - num_slots=self._num_slots, - underlying_dtype=self._underlying_dtype, - is_enc=self._is_enc, - fxp_fractional_bits=self._fxp_fractional_bits, - mult_count=mult_count, # Override the mult_count, may be higher - # than self._mult_count + 1 if mult_count_to_do was larger than - # 1. + # Perform the multiplication. + if self.is_encrypted: + raw_result = shell_ops.mul_ct_tf_scalar64( + self._context._raw_context, self._raw, fxp_multiplier ) - self._self_at_fxp_multiplier[mult_count] = left_shifted - return left_shifted + else: + raw_result = shell_ops.mul_pt_tf_scalar64( + self._context._raw_context, self._raw, fxp_multiplier + ) + shifted = ShellTensor64( + value=raw_result, + context=self._context, + underlying_dtype=self._underlying_dtype, + is_enc=self._is_enc, + fxp_fractional_bits=self._fxp_fractional_bits, + mul_count=mul_count, # Override the mul_count, may be higher + # than self._mul_count + 1 if mult_count_to_do was larger than + # 1. 
+ noise_bit_count=self._noise_bit_count * 2, + ) + self._self_at_fxp_multiplier[mul_count] = shifted + return shifted # This class uses a fixed point dtype depending on the dtype of tensorflow @@ -477,8 +566,9 @@ def _to_fixed_point(tf_tensor, fxp_fractional_bits): if tf_tensor.dtype in [tf.float32, tf.float64]: fxp_fractional_multiplier = 2**fxp_fractional_bits integer = tf.cast(tf_tensor, tf.int64) * fxp_fractional_multiplier - fractional = tf.cast(tf.math.mod(tf_tensor, 1), tf.int64) - return integer + fractional + fractional = tf_tensor - tf.cast(tf.cast(tf_tensor, tf.int64), tf_tensor.dtype) + fractional_quantized = tf.cast(fractional * fxp_fractional_multiplier, tf.int64) + return integer + fractional_quantized elif tf_tensor.dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]: # If the tensor dtype is uint64, we assume its fixed point representation # also fits in a uint64. @@ -507,11 +597,11 @@ def _from_fixed_point(fxp_tensor, fxp_fractional_bits, output_dtype): fractional = tf.cast( tf.bitwise.bitwise_and(fxp_tensor, fractional_mask), output_dtype ) - return integer + fractional / fxp_fractional_multiplier + return integer + (fractional / fxp_fractional_multiplier) elif output_dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]: - # When returning an integer datatype, the fractional bits of the fixed + # When returning an unsigned datatype, the fractional bits of the fixed # point representation cannot be stored. Throw the low precision bits - # away. TODO(jchoncholas): Round + # away. assert fxp_tensor.dtype == tf.uint64 tf_tensor = tf.bitwise.right_shift(fxp_tensor, fxp_fractional_bits) # round_up = tf.bitwise.right_shift(fxp_tensor, fxp_fractional_bits - 1) @@ -519,6 +609,9 @@ def _from_fixed_point(fxp_tensor, fxp_fractional_bits, output_dtype): # tf_tensor += round_up return tf.cast(tf_tensor, output_dtype) elif output_dtype in [tf.int8, tf.int16, tf.int32, tf.int64]: + # When returning an integer datatype, the fractional bits of the fixed + # point representation cannot be stored. Throw the low precision bits + # away. assert fxp_tensor.dtype == tf.int64 tf_tensor = tf.bitwise.right_shift(fxp_tensor, fxp_fractional_bits) # round_up = tf.bitwise.right_shift(fxp_tensor, fxp_fractional_bits - 1) @@ -533,16 +626,20 @@ def to_shell_tensor(context, tensor, fxp_fractional_bits=0): if isinstance(tensor, ShellTensor64): return tensor if isinstance(tensor, tf.Tensor): + assert isinstance( + context, ShellContext64 + ), f"Context must be a ShellContext64, instead got {type(context)}" # Convert to fixed point. fxp_tensor = _to_fixed_point(tensor, fxp_fractional_bits) return ShellTensor64( - value=shell_ops.polynomial_import64(context, fxp_tensor), + value=shell_ops.polynomial_import64(context._raw_context, fxp_tensor), context=context, - num_slots=tensor.shape[0], underlying_dtype=tensor.dtype, + is_enc=False, fxp_fractional_bits=fxp_fractional_bits, - mult_count=0, + mul_count=0, + noise_bit_count=context.noise_bits, ) else: raise ValueError("Cannot convert to ShellTensor64") @@ -562,53 +659,59 @@ def from_shell_tensor(s_tensor): # Convert from polynomial representation to plaintext tensorflow tensor. # Always convert to int64, then handle the fixed point as appropriate. 
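The _to_fixed_point change above matters for float inputs: the old code cast tf.math.mod(x, 1) straight to int64, which truncates to zero and silently drops the fractional part, while the new code scales the fraction by 2**fxp_fractional_bits before casting. A standalone sketch of the float round trip (the decode is shown as a plain division for brevity rather than the shift-and-mask used by _from_fixed_point):

import tensorflow as tf

frac_bits = 3
mult = 2**frac_bits

x = tf.constant([1.75, -0.5], dtype=tf.float32)

# New encoding: integer part plus quantized fractional part.
integer = tf.cast(x, tf.int64) * mult
fractional = x - tf.cast(tf.cast(x, tf.int64), x.dtype)
fxp = integer + tf.cast(fractional * mult, tf.int64)   # [14, -4]

# Simplified decode: undo the scale.
decoded = tf.cast(fxp, tf.float32) / mult              # [1.75, -0.5]

# The old encoding lost the fraction entirely:
old_fxp = tf.cast(x, tf.int64) * mult + tf.cast(tf.math.mod(x, 1), tf.int64)
# old_fxp == [8, 0]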
tf_tensor = shell_ops.polynomial_export64( - s_tensor._context, + s_tensor._context._raw_context, s_tensor._raw, dtype=shell_export_type, ) return _from_fixed_point( tf_tensor, - s_tensor.num_fxp_fractional_bits(), + s_tensor.num_fxp_fractional_bits, s_tensor._underlying_dtype, ) -def create_context64( - log_n, main_moduli, aux_moduli, plaintext_modulus, noise_variance, seed="" -): - return shell_ops.context_import64( - log_n=log_n, - main_moduli=main_moduli, - aux_moduli=aux_moduli, - plaintext_modulus=plaintext_modulus, - noise_variance=noise_variance, - seed=seed, - ) - - def create_key64(context): - return shell_ops.key_gen64(context) + if not isinstance(context, ShellContext64): + raise ValueError("Context must be a ShellContext64") + return shell_ops.key_gen64(context._raw_context) def create_rotation_key64(context, key): - return shell_ops.rotation_key_gen64(context, key) + if not isinstance(context, ShellContext64): + raise ValueError("Context must be a ShellContext64") + return shell_ops.rotation_key_gen64(context._raw_context, key) def matmul(x, y, rotation_key=None): + """Matrix multiplication is specialized to whether the operands are + plaintext or ciphertext. + + matmul(ciphertext, plaintext) works as in Tensorflow. + + matmul(plaintext, ciphertext) in tf-shell has slightly different semantics + than plaintext / Tensorflow. tf-shell affects top and bottom halves + independently, as well as the first dimension repeating the sum of either + the halves.""" if isinstance(x, ShellTensor64) and isinstance(y, tf.Tensor): # Convert y to fixed point and make sure it's multiplication level # matches x's. - fxp_tensor = _to_fixed_point(y, x.num_fxp_fractional_bits()) + fxp_tensor = _to_fixed_point(y, x.num_fxp_fractional_bits) + + # Noise grows from one multiplication then a sum over that dimension. + multiplication_noise = x._noise_bit_count + x._context.noise_bits + reduce_sum_noise = multiplication_noise + x.shape[1].bit_length() return ShellTensor64( - value=shell_ops.mat_mul_ct_pt64(x._context, x._raw, fxp_tensor), + value=shell_ops.mat_mul_ct_pt64( + x._context._raw_context, x._raw, fxp_tensor + ), context=x._context, - num_slots=x._num_slots, underlying_dtype=x._underlying_dtype, is_enc=True, fxp_fractional_bits=x._fxp_fractional_bits, - mult_count=x._mult_count + 1, + mul_count=x._mul_count + 1, + noise_bit_count=reduce_sum_noise, ) elif isinstance(x, tf.Tensor) and isinstance(y, ShellTensor64): @@ -622,18 +725,25 @@ def matmul(x, y, rotation_key=None): # Convert x to fixed point and make sure it's multiplication level # matches y's. - fxp_tensor = _to_fixed_point(x, y.num_fxp_fractional_bits()) + fxp_tensor = _to_fixed_point(x, y.num_fxp_fractional_bits) + + # Noise grows from doing one multiplication then a reduce_sum operation + # over the outer (ciphertext) dimension. dimension. The noise from the + # reduce_sum is a rough estimate that works for slots = 2**11. 
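Collected in one place, the noise_bit_count bookkeeping this patch threads through the ops follows a handful of rules. The sketch below is only a reading aid mirroring the rough estimates in the code (including the 60-bit rotation guess mentioned in the comment above); it is not a formal noise analysis.

def after_add(n):                             # ct+ct, ct+pt, pt+pt, negation, roll
    return n + 1

def after_mul(n_a, n_b):                      # ct*ct, ct*pt, pt*pt
    return n_a + n_b

def after_scalar_mul(n, context_noise_bits):  # ct * tf scalar
    return n + context_noise_bits

def after_reduce_sum_slots(n, num_slots):     # reduce_sum over axis 0, via rotations
    return n + num_slots.bit_length() + 1

def after_ct_pt_matmul(n, context_noise_bits, inner_dim):
    return n + context_noise_bits + inner_dim.bit_length()

def after_pt_ct_matmul(n, context_noise_bits, num_slots):
    mul_noise = n + context_noise_bits
    rot_noise = mul_noise + 60                # rough estimate for 2**11 slots
    return rot_noise + num_slots.bit_length()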
+ multiplication_noise = y._noise_bit_count + y._context.noise_bits + rotation_noise = multiplication_noise + 60 + reduce_sum_noise = rotation_noise + y._context.num_slots.bit_length() return ShellTensor64( value=shell_ops.mat_mul_pt_ct64( - y._context, rotation_key, fxp_tensor, y._raw + y._context._raw_context, rotation_key, fxp_tensor, y._raw ), context=y._context, - num_slots=y._num_slots, underlying_dtype=y._underlying_dtype, is_enc=True, fxp_fractional_bits=y._fxp_fractional_bits, - mult_count=y._mult_count + 1, + mul_count=y._mul_count + 1, + noise_bit_count=reduce_sum_noise, ) elif isinstance(x, ShellTensor64) and isinstance(y, ShellTensor64): diff --git a/shell_tensor/test/add_test.py b/shell_tensor/test/add_test.py index deb2383..5ea3ffb 100644 --- a/shell_tensor/test/add_test.py +++ b/shell_tensor/test/add_test.py @@ -20,35 +20,34 @@ class TestShellTensor(tf.test.TestCase): def _test_neg(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs zero additions, just negation. - min_val, max_val = test_utils.get_bounds_for_n_adds( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 0 - ) - - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - tf.cast(a, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 0) + + if a is None: + # Test parameters do not support zero additions at this + # precision. + print( + "Note: Skipping test neg with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return + + sa = shell_tensor.to_shell_tensor(test_context.shell_context, a, frac_bits) nsa = -sa self.assertAllClose(-a, shell_tensor.from_shell_tensor(nsa)) - ea = sa.get_encrypted(key) + ea = sa.get_encrypted(test_context.key) nea = -ea - self.assertAllClose(-a, nea.get_decrypted(key)) + self.assertAllClose(-a, nea.get_decrypted(test_context.key)) - self.assertAllClose(a, ea.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) def test_neg(self): for test_context in test_utils.test_contexts: - for frac_bits in test_utils.test_fxp_fractional_bits: - for test_dtype in test_utils.test_dtypes: + for frac_bits in [1]: + # for frac_bits in test_utils.test_fxp_fractional_bits: + for test_dtype in [tf.float32]: + # for test_dtype in test_utils.test_dtypes: if test_dtype.is_unsigned: # Negating an unsigned value is undefined. continue @@ -59,45 +58,44 @@ def test_neg(self): self._test_neg(test_context, test_dtype, frac_bits) def _test_ct_ct_add(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one addition. 
- min_val, max_val = test_utils.get_bounds_for_n_adds( + _, max_val = test_utils.get_bounds_for_n_adds( plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 ) - - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - - ea = shell_tensor.to_shell_tensor(context, a).get_encrypted(key) - eb = shell_tensor.to_shell_tensor(context, b).get_encrypted(key) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support one addition at this + # precision. + print( + "Note: Skipping test ct_ct_add with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return + + ea = shell_tensor.to_shell_tensor( + test_context.shell_context, a, frac_bits + ).get_encrypted(test_context.key) + eb = shell_tensor.to_shell_tensor( + test_context.shell_context, b, frac_bits + ).get_encrypted(test_context.key) ec = ea + eb - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(b, eb.get_decrypted(key)) - self.assertAllClose(a + b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(b, eb.get_decrypted(test_context.key)) + self.assertAllClose(a + b, ec.get_decrypted(test_context.key)) if plaintext_dtype.is_unsigned: # To test subtraction, ensure that a > b to avoid underflow. - eaa = shell_tensor.to_shell_tensor(context, a + max_val).get_encrypted(key) + eaa = shell_tensor.to_shell_tensor( + test_context.shell_context, a + max_val, frac_bits + ).get_encrypted(test_context.key) ed = eaa - eb - self.assertAllClose(a + max_val - b, ed.get_decrypted(key)) + self.assertAllClose(a + max_val - b, ed.get_decrypted(test_context.key)) else: ed = ea - eb - self.assertAllClose(a - b, ed.get_decrypted(key)) + self.assertAllClose(a - b, ed.get_decrypted(test_context.key)) def test_ct_ct_add(self): for test_context in test_utils.test_contexts: @@ -110,56 +108,54 @@ def test_ct_ct_add(self): self._test_ct_ct_add(test_context, test_dtype, frac_bits) def _test_ct_pt_add(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one addition. - min_val, max_val = test_utils.get_bounds_for_n_adds( + _, max_val = test_utils.get_bounds_for_n_adds( plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 ) - - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a) - sb = shell_tensor.to_shell_tensor(context, b) - ea = sa.get_encrypted(key) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support one addition at this + # precision. + print( + "Note: Skipping test ct_pt_add with dtype %s and frac_bits %d. Not enough precision to support this test." 
+ % (plaintext_dtype, frac_bits) + ) + return + + sa = shell_tensor.to_shell_tensor(test_context.shell_context, a, frac_bits) + sb = shell_tensor.to_shell_tensor(test_context.shell_context, b, frac_bits) + ea = sa.get_encrypted(test_context.key) ec = ea + sb - self.assertAllClose(a + b, ec.get_decrypted(key)) + self.assertAllClose(a + b, ec.get_decrypted(test_context.key)) ed = sb + ea - self.assertAllClose(a + b, ed.get_decrypted(key)) + self.assertAllClose(a + b, ed.get_decrypted(test_context.key)) if plaintext_dtype.is_unsigned: # To test subtraction, ensure that a > b to avoid underflow. - eaa = shell_tensor.to_shell_tensor(context, a + max_val).get_encrypted(key) + eaa = shell_tensor.to_shell_tensor( + test_context.shell_context, a + max_val, frac_bits + ).get_encrypted(test_context.key) ee = eaa - sb - self.assertAllClose(a + max_val - b, ee.get_decrypted(key)) + self.assertAllClose(a + max_val - b, ee.get_decrypted(test_context.key)) - sbb = shell_tensor.to_shell_tensor(context, b + max_val) + sbb = shell_tensor.to_shell_tensor( + test_context.shell_context, b + max_val, frac_bits + ) ef = sbb - ea - self.assertAllClose(b + max_val - a, ef.get_decrypted(key)) + self.assertAllClose(b + max_val - a, ef.get_decrypted(test_context.key)) else: ee = ea - sb - self.assertAllClose(a - b, ee.get_decrypted(key)) + self.assertAllClose(a - b, ee.get_decrypted(test_context.key)) ef = sb - ea - self.assertAllClose(b - a, ef.get_decrypted(key)) + self.assertAllClose(b - a, ef.get_decrypted(test_context.key)) # Ensure initial arguments are not modified. - self.assertAllClose(a, ea.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) self.assertAllClose(b, shell_tensor.from_shell_tensor(sb)) def test_ct_pt_add(self): @@ -173,56 +169,52 @@ def test_ct_pt_add(self): self._test_ct_pt_add(test_context, test_dtype, frac_bits) def _test_ct_tf_add(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one addition. - min_val, max_val = test_utils.get_bounds_for_n_adds( + _, max_val = test_utils.get_bounds_for_n_adds( plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 ) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a) - ea = sa.get_encrypted(key) + if a is None: + # Test parameters do not support one addition at this + # precision. + print( + "Note: Skipping test ct_tf_add with dtype %s and frac_bits %d. Not enough precision to support this test." 
+ % (plaintext_dtype, frac_bits) + ) + return + + sa = shell_tensor.to_shell_tensor(test_context.shell_context, a, frac_bits) + ea = sa.get_encrypted(test_context.key) ec = ea + b - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(a + b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(a + b, ec.get_decrypted(test_context.key)) ed = b + ea - self.assertAllClose(a + b, ed.get_decrypted(key)) + self.assertAllClose(a + b, ed.get_decrypted(test_context.key)) if plaintext_dtype.is_unsigned: # To test subtraction, ensure that a > b to avoid underflow. - eaa = shell_tensor.to_shell_tensor(context, a + max_val).get_encrypted(key) + eaa = shell_tensor.to_shell_tensor( + test_context.shell_context, a + max_val, frac_bits + ).get_encrypted(test_context.key) ee = eaa - b - self.assertAllClose(a + max_val - b, ee.get_decrypted(key)) + self.assertAllClose(a + max_val - b, ee.get_decrypted(test_context.key)) bb = b + max_val ef = bb - ea - self.assertAllClose(bb - a, ef.get_decrypted(key)) + self.assertAllClose(bb - a, ef.get_decrypted(test_context.key)) else: ee = ea - b - self.assertAllClose(a - b, ee.get_decrypted(key)) + self.assertAllClose(a - b, ee.get_decrypted(test_context.key)) ef = b - ea - self.assertAllClose(b - a, ef.get_decrypted(key)) + self.assertAllClose(b - a, ef.get_decrypted(test_context.key)) # Ensure initial arguemnts are not modified. - self.assertAllClose(a, ea.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) def test_ct_tf_add(self): for test_context in test_utils.test_contexts: @@ -235,36 +227,33 @@ def test_ct_tf_add(self): self._test_ct_tf_add(test_context, test_dtype, frac_bits) def _test_pt_pt_add(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - # This test performs one addition. - min_val, max_val = test_utils.get_bounds_for_n_adds( + _, max_val = test_utils.get_bounds_for_n_adds( plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 ) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a) - sb = shell_tensor.to_shell_tensor(context, b) + if a is None: + # Test parameters do not support one addition at this + # precision. + print( + "Note: Skipping test pt_pt_add with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return + + sa = shell_tensor.to_shell_tensor(test_context.shell_context, a, frac_bits) + sb = shell_tensor.to_shell_tensor(test_context.shell_context, b, frac_bits) sc = sa + sb self.assertAllClose(a + b, shell_tensor.from_shell_tensor(sc)) if plaintext_dtype.is_unsigned: # To test subtraction, ensure that a > b to avoid underflow. 
- saa = shell_tensor.to_shell_tensor(context, a + max_val) + saa = shell_tensor.to_shell_tensor( + test_context.shell_context, a + max_val, frac_bits + ) ee = saa - sb self.assertAllClose(a + max_val - b, shell_tensor.from_shell_tensor(ee)) else: @@ -286,28 +275,23 @@ def test_pt_pt_add(self): self._test_pt_pt_add(test_context, test_dtype, frac_bits) def _test_pt_tf_add(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - # This test performs one addition. - min_val, max_val = test_utils.get_bounds_for_n_adds( + _, max_val = test_utils.get_bounds_for_n_adds( plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 ) + a = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_adds(plaintext_dtype, test_context, frac_bits, 1) - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a) + if a is None: + # Test parameters do not support one addition at this + # precision. + print( + "Note: Skipping test pt_tf_add with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return + + sa = shell_tensor.to_shell_tensor(test_context.shell_context, a, frac_bits) sc = sa + b self.assertAllClose(a + b, shell_tensor.from_shell_tensor(sc)) @@ -317,7 +301,9 @@ def _test_pt_tf_add(self, test_context, plaintext_dtype, frac_bits): if plaintext_dtype.is_unsigned: # To test subtraction, ensure that a > b to avoid underflow. - saa = shell_tensor.to_shell_tensor(context, a + max_val) + saa = shell_tensor.to_shell_tensor( + test_context.shell_context, a + max_val, frac_bits + ) se = saa - b self.assertAllClose(a + max_val - b, shell_tensor.from_shell_tensor(se)) diff --git a/shell_tensor/test/composite_test.py b/shell_tensor/test/composite_test.py index b150912..637d829 100644 --- a/shell_tensor/test/composite_test.py +++ b/shell_tensor/test/composite_test.py @@ -20,14 +20,11 @@ class TestShellTensor(tf.test.TestCase): def _test_ct_ct_mulmul(self, test_context, plaintext_dtype, frac_bits): - shell_context = test_context.shell_context - key = shell_tensor.create_key64(shell_context) - # This test performs two multiplications. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 2 - ) - if max_val == 0: + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + + if a is None: # Test parameters do not support two multiplications at this # precision. 
print( @@ -36,40 +33,25 @@ def _test_ct_ct_mulmul(self, test_context, plaintext_dtype, frac_bits): ) return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int64, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int64, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - ea = shell_tensor.to_shell_tensor( - shell_context, a, fxp_fractional_bits=frac_bits - ).get_encrypted(key) + test_context.shell_context, a, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) eb = shell_tensor.to_shell_tensor( - shell_context, b, fxp_fractional_bits=frac_bits - ).get_encrypted(key) + test_context.shell_context, b, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) ec = ea * eb - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) # Here, ec has a mul_count of 1 while eb has a mul_count of 0. To # multiply them, eb needs to be left shifted by the number of fractional # bits in the fixed point representation to match ec. ShellTensor should # handle this automatically. ed = ec * eb - self.assertAllClose(a * b * b, ed.get_decrypted(key)) + self.assertAllClose(a * b * b, ed.get_decrypted(test_context.key)) - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(b, eb.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(b, eb.get_decrypted(test_context.key)) def test_ct_ct_mulmul(self): for test_context in test_utils.test_contexts: @@ -82,14 +64,11 @@ def test_ct_ct_mulmul(self): self._test_ct_ct_mulmul(test_context, test_dtype, frac_bits) def _test_ct_pt_mulmul(self, test_context, plaintext_dtype, frac_bits): - shell_context = test_context.shell_context - key = shell_tensor.create_key64(shell_context) - # This test performs two multiplications. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 2 - ) - if max_val == 0: + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + + if a is None: # Test parameters do not support two multiplications at this # precision. print( @@ -98,40 +77,25 @@ def _test_ct_pt_mulmul(self, test_context, plaintext_dtype, frac_bits): ) return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int64, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int64, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - ea = shell_tensor.to_shell_tensor( - shell_context, a, fxp_fractional_bits=frac_bits - ).get_encrypted(key) + test_context.shell_context, a, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) eb = shell_tensor.to_shell_tensor( - shell_context, b, fxp_fractional_bits=frac_bits - ).get_encrypted(key) + test_context.shell_context, b, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) ec = ea * eb - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) # Here, ec has a mul_count of 1 while eb has a mul_count of 0. To # multiply them, eb needs to be left shifted by the number of fractional # bits in the fixed point representation to match ec. ShellTensor should # handle this automatically. 
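The depth matching mentioned in the comment above can be checked by hand: with f fractional bits a depth-0 value sits at scale 2**f and a depth-1 product at scale 2**(2*f), so moving up a level multiplies by 2**f and moving down divides by it in the plaintext field, i.e. multiplies by a modular inverse, as the comment inside get_at_multiplication_count describes. A small numeric check, using the plaintext modulus from the matmul test context in this patch:

t = 134246401       # plaintext modulus; odd, so 2 is invertible mod t
f = 3               # fxp_fractional_bits

x, y = 2.5, 1.5
fx, fy = round(x * 2**f), round(y * 2**f)   # 20, 12  (scale 2**f)
prod = fx * fy                              # 240     (scale 2**(2*f))

# Left shift: bring a depth-0 operand up to depth 1.
assert fx * 2**f == round(x * 2**(2 * f))   # 160

# Right shift: bring the depth-1 product back to depth 0 via the modular
# inverse of 2**f (Python 3.8+ pow accepts a negative exponent with a modulus).
assert (prod * pow(2, -f, t)) % t == round(x * y * 2**f)   # 30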
ed = ec * b - self.assertAllClose(a * b * b, ed.get_decrypted(key)) + self.assertAllClose(a * b * b, ed.get_decrypted(test_context.key)) - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(b, eb.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(b, eb.get_decrypted(test_context.key)) def test_ct_pt_mulmul(self): for test_context in test_utils.test_contexts: @@ -143,6 +107,50 @@ def test_ct_pt_mulmul(self): ): self._test_ct_pt_mulmul(test_context, test_dtype, frac_bits) + def _test_get_at_mul_count(self, test_context, plaintext_dtype, frac_bits): + # This test performs two multiplications. + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 2) + + if a is None: + # Test parameters do not support two multiplications at this + # precision. + print( + "Note: Skipping test ct_pt_mulmul with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return + + ea = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) + eb = shell_tensor.to_shell_tensor( + test_context.shell_context, b, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) + + ec = ea * eb + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) + + ec_right_shifted = ec.get_at_multiplication_count(0) + + # Here, ec_right_shifted and eb both have a mul_count of 0. Multiplying + # them should be straightforward. + ed = ec_right_shifted * b + self.assertAllClose(a * b * b, ed.get_decrypted(test_context.key)) + + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(b, eb.get_decrypted(test_context.key)) + + def test_get_at_mul_count(self): + for test_context in test_utils.test_contexts: + for frac_bits in test_utils.test_fxp_fractional_bits: + for test_dtype in test_utils.test_dtypes: + with self.subTest( + "test_get_at_mul_count with fractional bits %d and dtype %s" + % (frac_bits, test_dtype) + ): + self._test_get_at_mul_count(test_context, test_dtype, frac_bits) + if __name__ == "__main__": tf.test.main() diff --git a/shell_tensor/test/mul_test.py b/shell_tensor/test/mul_test.py index 8d5f017..8e8a189 100644 --- a/shell_tensor/test/mul_test.py +++ b/shell_tensor/test/mul_test.py @@ -20,49 +20,48 @@ class TestShellTensor(tf.test.TestCase): # Matrix multiplication tests require smaller parameters to avoid overflow. - matmul_fxp_fractional_bits = [0, 1] - matmul_max_val = 3 - matmul_val_offset = -1 matmul_dtypes = [ tf.int32, tf.int64, tf.float32, tf.float64, ] + matmul_contexts = [ + # Num plaintext bits: 27, noise bits: 66, num rns moduli: 2 + test_utils.TestContext( + outer_shape=[], # dummy + log_slots=11, + main_moduli=[281474976768001, 281474976829441], + plaintext_modulus=134246401, + ), + ] def _test_ct_ct_mul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one multiplication. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 - ) + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support reduce_sum at this precision. 
+ print( + "Note: Skipping test ct_ct_mul with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sa = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sb = shell_tensor.to_shell_tensor( + test_context.shell_context, b, fxp_fractional_bits=frac_bits ) - b = tf.cast(b, plaintext_dtype) - - sa = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=frac_bits) - sb = shell_tensor.to_shell_tensor(context, b, fxp_fractional_bits=frac_bits) - ea = sa.get_encrypted(key) - eb = sb.get_encrypted(key) + ea = sa.get_encrypted(test_context.key) + eb = sb.get_encrypted(test_context.key) ec = ea * eb - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(b, eb.get_decrypted(key)) - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(b, eb.get_decrypted(test_context.key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) def test_ct_ct_mul(self): for test_context in test_utils.test_contexts: @@ -75,38 +74,32 @@ def test_ct_ct_mul(self): self._test_ct_ct_mul(test_context, test_dtype, frac_bits) def _test_ct_pt_mul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one multiplication. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 - ) + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support reduce_sum at this precision. + print( + "Note: Skipping test ct_pt_mul with dtype %s and frac_bits %d. Not enough precision to support this test." 
+ % (plaintext_dtype, frac_bits) + ) + return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sa = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sb = shell_tensor.to_shell_tensor( + test_context.shell_context, b, fxp_fractional_bits=frac_bits ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=frac_bits) - sb = shell_tensor.to_shell_tensor(context, b, fxp_fractional_bits=frac_bits) - ea = sa.get_encrypted(key) + ea = sa.get_encrypted(test_context.key) ec = ea * sb - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) ed = sb * ea - self.assertAllClose(a * b, ed.get_decrypted(key)) + self.assertAllClose(a * b, ed.get_decrypted(test_context.key)) def test_ct_pt_mul(self): for test_context in test_utils.test_contexts: @@ -119,37 +112,29 @@ def test_ct_pt_mul(self): self._test_ct_pt_mul(test_context, test_dtype, frac_bits) def _test_ct_tf_scalar_mul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one multiplication. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 - ) + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support reduce_sum at this precision. + print( + "Note: Skipping test ct_tf_scalar_mul with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sa = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=frac_bits) - ea = sa.get_encrypted(key) + ea = sa.get_encrypted(test_context.key) ec = ea * b - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) ed = b * ea - self.assertAllClose(a * b, ed.get_decrypted(key)) + self.assertAllClose(a * b, ed.get_decrypted(test_context.key)) def test_ct_tf_scalar_mul(self): for test_context in test_utils.test_contexts: @@ -162,37 +147,29 @@ def test_ct_tf_scalar_mul(self): self._test_ct_tf_scalar_mul(test_context, test_dtype, frac_bits) def _test_ct_tf_mul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - # This test performs one multiplication. 
- min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 - ) + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support reduce_sum at this precision. + print( + "Note: Skipping test ct_pt_mul with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sa = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, - ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=frac_bits) - ea = sa.get_encrypted(key) + ea = sa.get_encrypted(test_context.key) ec = ea * b - self.assertAllClose(a, ea.get_decrypted(key)) - self.assertAllClose(a * b, ec.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) + self.assertAllClose(a * b, ec.get_decrypted(test_context.key)) ed = b * ea - self.assertAllClose(a * b, ed.get_decrypted(key)) + self.assertAllClose(a * b, ed.get_decrypted(test_context.key)) def test_ct_tf_mul(self): for test_context in test_utils.test_contexts: @@ -205,29 +182,24 @@ def test_ct_tf_mul(self): self._test_ct_tf_mul(test_context, test_dtype, frac_bits) def _test_pt_pt_mul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - # This test performs one multiplication. - min_val, max_val = test_utils.get_bounds_for_n_muls( - plaintext_dtype, test_context.plaintext_modulus, frac_bits, 1 - ) + a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1) + + if a is None: + # Test parameters do not support reduce_sum at this precision. + print( + "Note: Skipping test pt_pt_mul with dtype %s and frac_bits %d. Not enough precision to support this test." 
+ % (plaintext_dtype, frac_bits) + ) + return - a = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sa = shell_tensor.to_shell_tensor( + test_context.shell_context, a, fxp_fractional_bits=frac_bits ) - a = tf.cast(a, plaintext_dtype) - b = tf.random.uniform( - [test_context.slots, 2, 3, 4], - dtype=tf.int32, - maxval=max_val, - minval=min_val, + sb = shell_tensor.to_shell_tensor( + test_context.shell_context, b, fxp_fractional_bits=frac_bits ) - b = tf.cast(b, plaintext_dtype) - sa = shell_tensor.to_shell_tensor(context, a, fxp_fractional_bits=frac_bits) - sb = shell_tensor.to_shell_tensor(context, b, fxp_fractional_bits=frac_bits) sc = sa * sb self.assertAllClose(a * b, shell_tensor.from_shell_tensor(sc)) @@ -243,30 +215,37 @@ def test_pt_pt_mul(self): self._test_pt_pt_mul(test_context, test_dtype, frac_bits) def _test_ct_tf_matmul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - - a = ( - tf.random.uniform( - [test_context.slots, 5], dtype=tf.int32, maxval=self.matmul_max_val - ) - - self.matmul_val_offset + a = test_utils.uniform_for_n_muls( + plaintext_dtype, + test_context, + frac_bits, + 1, + shape=[test_context.slots, 5], + subsequent_adds=5, # For dim(1) ) - a = tf.cast(a, plaintext_dtype) - b = ( - tf.random.uniform([5, 7], dtype=tf.int32, maxval=self.matmul_max_val) - - self.matmul_val_offset + b = test_utils.uniform_for_n_muls( + plaintext_dtype, test_context, frac_bits, 1, shape=[5, 7] ) - b = tf.cast(b, plaintext_dtype) + + if a is None or b is None: + print( + "Note: Skipping test ct_tf_matmul with dtype %s and frac_bits %d. Not enough precision to support this test." + % (plaintext_dtype, frac_bits) + ) + return ea = shell_tensor.to_shell_tensor( - context, a, fxp_fractional_bits=frac_bits - ).get_encrypted(key) + test_context.shell_context, a, fxp_fractional_bits=frac_bits + ).get_encrypted(test_context.key) ec = shell_tensor.matmul(ea, b) - self.assertAllClose(a, ea.get_decrypted(key)) + self.assertAllClose(a, ea.get_decrypted(test_context.key)) c = tf.matmul(a, b) - self.assertAllClose(c, ec.get_decrypted(key)) + self.assertAllClose( + c, + ec.get_decrypted(test_context.key) + # , atol=5 * 2 ** (-frac_bits - 1) + ) def test_ct_tf_matmul(self): for test_context in test_utils.test_contexts: @@ -312,40 +291,48 @@ def plaintext_matmul(self, a, b): return tf.concat([top, bottom], axis=0) def _test_tf_ct_matmul(self, test_context, plaintext_dtype, frac_bits): - context = test_context.shell_context - key = shell_tensor.create_key64(context) - rotation_key = shell_tensor.create_rotation_key64(context, key) - - a = ( - tf.random.uniform( - [3, 5, test_context.slots], - dtype=tf.int32, - maxval=self.matmul_max_val, - ) - - self.matmul_val_offset + a = test_utils.uniform_for_n_muls( + plaintext_dtype, + test_context, + frac_bits, + 1, + shape=[3, 5, test_context.slots], + subsequent_adds=test_context.slots / 2, ) - a = tf.cast(a, plaintext_dtype) - b = ( - tf.random.uniform( - [test_context.slots, 2], dtype=tf.int32, maxval=self.matmul_max_val - ) - - self.matmul_val_offset + b = test_utils.uniform_for_n_muls( + plaintext_dtype, + test_context, + frac_bits, + 1, + shape=[test_context.slots, 2], + subsequent_adds=test_context.slots / 2, ) - b = tf.cast(b, plaintext_dtype) + if a is None or b is None: + print( + "Note: Skipping test tf_ct_matmul with dtype %s and frac_bits %d. Not enough precision to support this test." 
+                % (plaintext_dtype, frac_bits)
+            )
+            return
 
         eb = shell_tensor.to_shell_tensor(
-            context, b, fxp_fractional_bits=frac_bits
-        ).get_encrypted(key)
+            test_context.shell_context, b, fxp_fractional_bits=frac_bits
+        ).get_encrypted(test_context.key)
 
-        ec = shell_tensor.matmul(a, eb, rotation_key)
-        self.assertAllClose(b, eb.get_decrypted(key))
+        ec = shell_tensor.matmul(a, eb, test_context.rotation_key)
+        self.assertAllClose(
+            b,
+            eb.get_decrypted(test_context.key),
+        )
 
         check_c = self.plaintext_matmul(a, b)
-        self.assertAllClose(check_c.shape, ec.shape)
-        self.assertAllClose(check_c, ec.get_decrypted(key))
+        self.assertAllClose(
+            check_c,
+            ec.get_decrypted(test_context.key),
+            # atol=test_context.slots * 2 ** (-frac_bits - 1),
+        )
 
     def test_tf_ct_matmul(self):
-        for test_context in test_utils.test_contexts:
+        for test_context in self.matmul_contexts:
             for frac_bits in test_utils.test_fxp_fractional_bits:
                 for test_dtype in self.matmul_dtypes:
                     with self.subTest(
diff --git a/shell_tensor/test/rotation_test.py b/shell_tensor/test/rotation_test.py
index a368514..7b055dc 100644
--- a/shell_tensor/test/rotation_test.py
+++ b/shell_tensor/test/rotation_test.py
@@ -30,17 +30,6 @@ class TestShellTensorRotation(tf.test.TestCase):
     roll_test_outer_shape = [3, 3]
     test_outer_shape = [2, 5, 4]
 
-    def _test_keygen(self, test_context):
-        context = test_context.shell_context
-        key = shell_tensor.create_key64(context)
-        rotation_key = shell_tensor.create_rotation_key64(context, key)
-        assert rotation_key is not None
-
-    def test_keygen(self):
-        for test_context in test_utils.test_contexts:
-            with self.subTest("keygen"):
-                self._test_keygen(test_context)
-
     # TensorFlow's roll has slightly different semantics than encrypted roll.
     # Encrypted rotation affects top and bottom halves independently.
     # This function emulates this in plaintext by splitting the tensor in half,
@@ -52,9 +41,7 @@ def plaintext_roll(self, t, shift):
         rotated_tftensor = tf.concat([top, bottom], axis=0)
         return rotated_tftensor
 
-    def _test_roll(self, test_context, key, rotation_key, plaintext_dtype, roll_num):
-        context = test_context.shell_context
-
+    def _test_roll(self, test_context, plaintext_dtype, roll_num):
         # Create a tensor with the shape of slots x (outer_shape) where each
         # column of the first dimensions counts from 0 to slots-1.
         tftensor = tf.range(0, test_context.slots, delta=1, dtype=plaintext_dtype)
@@ -66,18 +53,15 @@ def _test_roll(self, test_context, key, rotation_key, plaintext_dtype, roll_num)
         rolled_tftensor = self.plaintext_roll(tftensor, roll_num)
 
-        s = shell_tensor.to_shell_tensor(context, tftensor)
-        enc = s.get_encrypted(key)
+        s = shell_tensor.to_shell_tensor(test_context.shell_context, tftensor)
+        enc = s.get_encrypted(test_context.key)
 
-        rolled_enc = enc.roll(rotation_key, roll_num)
-        rolled_result = rolled_enc.get_decrypted(key)
+        rolled_enc = enc.roll(test_context.rotation_key, roll_num)
+        rolled_result = rolled_enc.get_decrypted(test_context.key)
         self.assertAllClose(rolled_tftensor, rolled_result)
 
     def test_roll(self):
         for test_context in test_utils.test_contexts:
-            context = test_context.shell_context
-            key = shell_tensor.create_key64(context)
-            rotation_key = shell_tensor.create_rotation_key64(context, key)
             rotation_range = test_context.slots // 2 - 1
 
             for test_dtype in self.rotation_dtypes:
@@ -85,9 +69,7 @@ def test_roll(self):
                     with self.subTest(
                         "rotate with dtype %s, rotating by %s" % (test_dtype, roll_num)
                     ):
-                        self._test_roll(
-                            test_context, key, rotation_key, test_dtype, roll_num
-                        )
+                        self._test_roll(test_context, test_dtype, roll_num)
 
     # TensorFlow's reduce_sum has slightly different semantics than encrypted
     # reduce_sum. Encrypted reduce_sum affects top and bottom halves
@@ -103,20 +85,13 @@ def plaintext_reduce_sum_axis_0(self, t):
         return tf.concat([repeated_bottom_answer, repeated_top_answer], 0)
 
-    def _test_reduce_sum_axis_0(
-        self, test_context, key, rotation_key, plaintext_dtype, frac_bits
-    ):
-        context = test_context.shell_context
-
+    def _test_reduce_sum_axis_0(self, test_context, plaintext_dtype, frac_bits):
         # reduce_sum across axis 0 requires adding over all the slots.
-        min_val, max_val = test_utils.get_bounds_for_n_adds(
-            plaintext_dtype,
-            test_context.plaintext_modulus,
-            frac_bits,
-            test_context.slots,
+        tftensor = test_utils.uniform_for_n_adds(
+            plaintext_dtype, test_context, frac_bits, test_context.slots
         )
-        if max_val is 0:
+        if tftensor is None:
             # Test parameters do not support reduce_sum at this precision.
             print(
                 "Note: Skipping test reduce_sum_axis0 with dtype %s and frac_bits %d. Not enough precision to support this test."
@@ -124,30 +99,18 @@ def _test_reduce_sum_axis_0(
             )
             return
 
-        test_shape = self.test_outer_shape.copy()
-        test_shape.insert(0, test_context.slots)
-
-        tftensor = tf.random.uniform(
-            test_shape,
-            dtype=tf.int64,
-            maxval=max_val,
-            minval=min_val,
+        s = shell_tensor.to_shell_tensor(
+            test_context.shell_context, tftensor, frac_bits
         )
-        tftensor = tf.cast(tftensor, plaintext_dtype)
-        s = shell_tensor.to_shell_tensor(context, tftensor)
-        enc = s.get_encrypted(key)
+        enc = s.get_encrypted(test_context.key)
 
-        enc_reduce_sum = enc.reduce_sum(axis=0, rotation_key=rotation_key)
+        enc_reduce_sum = enc.reduce_sum(axis=0, rotation_key=test_context.rotation_key)
 
-        tftensor_out = enc_reduce_sum.get_decrypted(key)
+        tftensor_out = enc_reduce_sum.get_decrypted(test_context.key)
         self.assertAllClose(tftensor_out, self.plaintext_reduce_sum_axis_0(tftensor))
 
     def test_reduce_sum_axis_0(self):
         for test_context in test_utils.test_contexts:
-            context = test_context.shell_context
-            key = shell_tensor.create_key64(context)
-            rotation_key = shell_tensor.create_rotation_key64(context, key)
-
             for frac_bits in test_utils.test_fxp_fractional_bits:
                 for test_dtype in self.rotation_dtypes:
                     with self.subTest(
@@ -155,24 +118,18 @@ def test_reduce_sum_axis_0(self):
                         % (frac_bits, test_dtype)
                     ):
                         self._test_reduce_sum_axis_0(
-                            test_context, key, rotation_key, test_dtype, frac_bits
+                            test_context, test_dtype, frac_bits
                         )
 
     def _test_reduce_sum_axis_n(
         self, test_context, plaintext_dtype, frac_bits, outer_axis
     ):
-        context = test_context.shell_context
-        key = shell_tensor.create_key64(context)
-
         # reduce_sum across `axis` requires adding over that dimension.
-        min_val, max_val = test_utils.get_bounds_for_n_adds(
-            plaintext_dtype,
-            test_context.plaintext_modulus,
-            frac_bits,
-            self.test_outer_shape[outer_axis],
+        tftensor = test_utils.uniform_for_n_adds(
+            plaintext_dtype, test_context, frac_bits, self.test_outer_shape[outer_axis]
         )
-        if max_val == 0:
+        if tftensor is None:
             # Test parameters do not support reduce_sum at this precision.
             print(
                 "Note: Skipping test reduce_sum_axis0 with dtype %s and frac_bits %d. Not enough precision to support this test."
@@ -180,22 +137,14 @@ def _test_reduce_sum_axis_n(
             )
             return
 
-        test_shape = self.test_outer_shape.copy()
-        test_shape.insert(0, test_context.slots)
-
-        tftensor = tf.random.uniform(
-            test_shape,
-            dtype=tf.int64,
-            maxval=max_val,
-            minval=min_val,
+        s = shell_tensor.to_shell_tensor(
+            test_context.shell_context, tftensor, frac_bits
         )
-        tftensor = tf.cast(tftensor, plaintext_dtype)
-        s = shell_tensor.to_shell_tensor(context, tftensor)
-        enc = s.get_encrypted(key)
+        enc = s.get_encrypted(test_context.key)
 
         enc_reduce_sum = enc.reduce_sum(axis=outer_axis + 1)
 
-        tftensor_out = enc_reduce_sum.get_decrypted(key)
+        tftensor_out = enc_reduce_sum.get_decrypted(test_context.key)
         self.assertAllClose(tftensor_out, tf.reduce_sum(tftensor, axis=outer_axis + 1))
 
     def test_reduce_sum_axis_n(self):
diff --git a/shell_tensor/test/test_utils.py b/shell_tensor/test/test_utils.py
index f4945c9..a932800 100644
--- a/shell_tensor/test/test_utils.py
+++ b/shell_tensor/test/test_utils.py
@@ -17,7 +17,7 @@
 import shell_tensor
 import math
 
-test_fxp_fractional_bits = [0, 1]
+test_fxp_fractional_bits = [0, 1, 2, 3, 4]
 test_dtypes = [
     tf.int8,
     tf.uint8,
@@ -33,7 +33,8 @@ class TestContext:
-    def __init__(self, log_slots, plaintext_modulus):
+    def __init__(self, outer_shape, log_slots, main_moduli, plaintext_modulus):
+        self.outer_shape = outer_shape
         self.log_slots = log_slots
         self.slots = 2**log_slots
 
@@ -41,22 +42,35 @@ def __init__(self, log_slots, plaintext_modulus):
         self.shell_context = shell_tensor.create_context64(
             log_n=log_slots,
-            main_moduli=[8556589057, 8388812801],
-            aux_moduli=[34359709697],
+            main_moduli=main_moduli,
+            aux_moduli=[],
             plaintext_modulus=plaintext_modulus,
             noise_variance=4,
             seed="",
         )
 
+        self.key = shell_tensor.create_key64(self.shell_context)
 
-test_contexts = [TestContext(log_slots=11, plaintext_modulus=40961)]
+        self.rotation_key = shell_tensor.create_rotation_key64(
+            self.shell_context, self.key
+        )
+
+
+test_contexts = [
+    TestContext(
+        outer_shape=[3, 2, 3],
+        log_slots=11,
+        main_moduli=[8556589057, 8388812801],
+        plaintext_modulus=40961,
+    ),
+]
 
 
-def get_bounds_for_n_muls(dtype, plaintext_modulus, num_fxp_fractional_bits, num_muls):
+def get_bounds_for_n_muls(dtype, plaintext_modulus, num_frac_bits, num_muls):
     """Returns a safe range for plaintext values when doing a given number of
     multiplications. The range is determined by both the plaintext modulus and
     the datatype."""
-    max_fractional_bits = 2**num_muls * num_fxp_fractional_bits
+    max_fractional_bits = 2**num_muls * num_frac_bits
     max_fractional_value = 2**max_fractional_bits
 
     # Make sure not to exceed the range of the dtype.
@@ -85,10 +99,10 @@ def get_bounds_for_n_muls(dtype, plaintext_modulus, num_fxp_fractional_bits, num
     return min_val, max_val
 
 
-def get_bounds_for_n_adds(dtype, plaintext_modulus, num_fxp_fractional_bits, num_adds):
+def get_bounds_for_n_adds(dtype, plaintext_modulus, num_frac_bits, num_adds):
     """Returns a safe range for plaintext values when doing a given number of
     additions."""
-    max_fractional_bits = num_fxp_fractional_bits
+    max_fractional_bits = num_frac_bits
     max_fractional_value = 2**max_fractional_bits
 
     # Make sure not to exceed the range of the dtype.
@@ -115,3 +129,76 @@ def get_bounds_for_n_adds(dtype, plaintext_modulus, num_fxp_fractional_bits, num
     max_val = min(max_plaintext_dtype, max_plaintext_modulus)
 
     return min_val, max_val
+
+
+def uniform_for_n_adds(dtype, test_context, num_fxp_frac_bits, num_adds):
+    """Returns a random tensor with values in the range of the datatype and
+    plaintext modulus. The elements support n additions without overflowing
+    either the datatype or the plaintext modulus. Floating point datatypes
+    return fractional values at the appropriate quantization."""
+    min_val, max_val = get_bounds_for_n_adds(
+        dtype, test_context.plaintext_modulus, num_fxp_frac_bits, num_adds
+    )
+
+    if max_val < 2 ** (-num_fxp_frac_bits - 1):
+        return None
+
+    if dtype.is_floating:
+        min_val *= 2**num_fxp_frac_bits
+        max_val *= 2**num_fxp_frac_bits
+
+    shape = test_context.outer_shape.copy()
+    shape.insert(0, test_context.slots)
+
+    rand = tf.random.uniform(
+        shape,
+        dtype=tf.int64,
+        maxval=max_val,
+        minval=min_val,
+    )
+
+    rand = tf.cast(rand, dtype)
+    if dtype.is_floating:
+        rand /= 2**num_fxp_frac_bits
+
+    return rand
+
+
+def uniform_for_n_muls(
+    dtype, test_context, num_fxp_frac_bits, num_muls, shape=None, subsequent_adds=0
+):
+    """Returns a random tensor with values in the range of the datatype and
+    plaintext modulus. The elements support n multiplications without
+    overflowing either the datatype or the plaintext modulus. Floating point
+    datatypes return fractional values at the appropriate quantization.
+    """
+    min_val, max_val = get_bounds_for_n_muls(
+        dtype, test_context.plaintext_modulus, num_fxp_frac_bits, num_muls
+    )
+
+    min_val = math.floor(min_val / (subsequent_adds + 1))
+    max_val = math.floor(max_val / (subsequent_adds + 1))
+
+    if max_val < 2 ** (-num_fxp_frac_bits - 1):
+        return None
+
+    if dtype.is_floating:
+        min_val *= 2**num_fxp_frac_bits
+        max_val *= 2**num_fxp_frac_bits
+
+    if shape is None:
+        shape = test_context.outer_shape.copy()
+        shape.insert(0, test_context.slots)
+
+    rand = tf.random.uniform(
+        shape,
+        dtype=tf.int64,
+        maxval=max_val,
+        minval=min_val,
+    )
+
+    rand = tf.cast(rand, dtype)
+    if dtype.is_floating:
+        rand /= 2**num_fxp_frac_bits
+
+    return rand
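For reference, a minimal sketch of how a test body reads with the new helpers, written out plainly rather than as a diff. Every call it makes (the TestContext fields, uniform_for_n_muls, to_shell_tensor, get_encrypted, get_decrypted, ciphertext-times-plaintext multiplication) appears in the patch above; the module import lines and the function name example_ct_pt_mul are illustrative assumptions and are not part of the patch.

import shell_tensor
import test_utils  # assumed import path for the helpers added above


def example_ct_pt_mul(test_context, plaintext_dtype, frac_bits):
    # Draw operands small enough to survive one multiplication at this
    # fixed-point precision. The helper returns None when the dtype or the
    # plaintext modulus cannot represent the product.
    a = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1)
    b = test_utils.uniform_for_n_muls(plaintext_dtype, test_context, frac_bits, 1)
    if a is None or b is None:
        return None  # Skip: not enough precision for this dtype / frac_bits pair.

    # Encode into the slot representation, then encrypt with the key that
    # TestContext now builds once per context.
    sa = shell_tensor.to_shell_tensor(
        test_context.shell_context, a, fxp_fractional_bits=frac_bits
    )
    ea = sa.get_encrypted(test_context.key)

    # Ciphertext times plaintext TensorFlow tensor, then decrypt.
    ec = ea * b
    product = ec.get_decrypted(test_context.key)

    # product should match a * b up to fixed-point quantization error.
    return product

Because TestContext now creates self.key and self.rotation_key once in its constructor, the per-test create_key64 / create_rotation_key64 calls are dropped throughout the tests above.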