Skip to content

Commit

Permalink
Make ShellTensor's scaling factor a public property.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 30, 2024
1 parent d68f374 commit 0f0f7d1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tf_shell/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def is_encrypted(self):

@property
def level(self):
return self.level
return self._level

@property
def scaling_factor(self):
return self._scaling_factor

def __getitem__(self, slice):
slots = slice[0]
Expand Down Expand Up @@ -377,9 +381,8 @@ def __mul__(self, other):
assert other.shape[0] == 1
other = tf.reshape(other, other.shape[1:])

# Encode the other scalar tensor to the same scaling factor as
# self.
other = _encode_scaling(other, self._scaling_factor)
# Encode the other scalar tensor to the context scaling factor.
other = _encode_scaling(other, self._context.scaling_factor)

if self.is_encrypted:
raw_result = shell_ops.mul_ct_tf_scalar64(
Expand All @@ -400,7 +403,7 @@ def __mul__(self, other):
_level=self._level,
_num_mod_reductions=self._num_mod_reductions,
_underlying_dtype=self._underlying_dtype,
_scaling_factor=self._scaling_factor**2,
_scaling_factor=self._scaling_factor * self._context.scaling_factor,
_is_enc=self._is_enc,
_is_fast_rotated=self._is_fast_rotated,
)
Expand Down

0 comments on commit 0f0f7d1

Please sign in to comment.