diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 2aa3be3..a9eb2e9 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -63,7 +63,11 @@ def is_encrypted(self): @property def level(self): - return self.level + return self._level + + @property + def scaling_factor(self): + return self._scaling_factor def __getitem__(self, slice): slots = slice[0] @@ -377,9 +381,8 @@ def __mul__(self, other): assert other.shape[0] == 1 other = tf.reshape(other, other.shape[1:]) - # Encode the other scalar tensor to the same scaling factor as - # self. - other = _encode_scaling(other, self._scaling_factor) + # Encode the other scalar tensor to the context scaling factor. + other = _encode_scaling(other, self._context.scaling_factor) if self.is_encrypted: raw_result = shell_ops.mul_ct_tf_scalar64( @@ -400,7 +403,7 @@ def __mul__(self, other): _level=self._level, _num_mod_reductions=self._num_mod_reductions, _underlying_dtype=self._underlying_dtype, - _scaling_factor=self._scaling_factor**2, + _scaling_factor=self._scaling_factor * self._context.scaling_factor, _is_enc=self._is_enc, _is_fast_rotated=self._is_fast_rotated, )