Skip to content

Commit

Permalink
Fix tracking scaling factor through addition and subtraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 10, 2024
1 parent a588bcf commit 2144ee8
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tf_shell/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __add__(self, other):
_level=matched_self._level,
_num_mod_reductions=matched_self._num_mod_reductions,
_underlying_dtype=self._underlying_dtype,
_scaling_factor=self._scaling_factor,
_scaling_factor=matched_self._scaling_factor,
_is_enc=self._is_enc or other._is_enc,
_is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated,
)
Expand Down Expand Up @@ -199,7 +199,7 @@ def __sub__(self, other):
_level=matched_self._level,
_num_mod_reductions=matched_self._num_mod_reductions,
_underlying_dtype=self._underlying_dtype,
_scaling_factor=self._scaling_factor,
_scaling_factor=matched_self._scaling_factor,
_is_enc=self._is_enc or other._is_enc,
_is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated,
)
Expand Down Expand Up @@ -488,9 +488,9 @@ def _match_moduli_and_scaling(x, y):

# Match the scaling factors.
while x._scaling_factor > y._scaling_factor:
y = y.__mul__(x._scaling_factor)
y = y.__mul__(y._scaling_factor)
while x._scaling_factor < y._scaling_factor:
x = x.__mul__(y._scaling_factor)
x = x.__mul__(x._scaling_factor)

return x, y

Expand Down

0 comments on commit 2144ee8

Please sign in to comment.