Commit

Add name scope to common operations.
james-choncholas committed Oct 10, 2024
1 parent 451c2d9 commit e971b9f
Showing 1 changed file with 54 additions and 50 deletions.
104 changes: 54 additions & 50 deletions tf_shell/python/shell_tensor.py
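
Note on what the change does: tf.name_scope only affects how the resulting graph ops are named (and therefore how they group in TensorBoard and appear in error messages); it does not change what the wrapped ops compute. A minimal sketch of the pattern this commit applies, reusing the "encode_scaling" scope name from the diff below (the scaling factor of 3 is just an illustrative value, not one taken from the library):

    import tensorflow as tf

    @tf.function
    def encode(x, scaling_factor=3):
        # Ops created here are named "encode_scaling/..." in the traced graph.
        with tf.name_scope("encode_scaling"):
            return tf.cast(tf.round(x * scaling_factor), tf.int64)

    print(encode(tf.constant([1.2, 2.7])))  # tf.Tensor([4 8], shape=(2,), dtype=int64)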
@@ -473,24 +473,25 @@ def mod_reduce_tensor64(shell_tensor):


def _match_moduli_and_scaling(x, y):
-    # Mod switch to the smaller modulus of the two.
-    while x._num_mod_reductions < y._num_mod_reductions:
-        x = mod_reduce_tensor64(x)
-    while x._num_mod_reductions > y._num_mod_reductions:
-        y = mod_reduce_tensor64(y)
-
-    # First make sure the scaling factors are compatible.
-    frac = x._scaling_factor / y._scaling_factor
-    if abs(frac - int(frac)) != 0:
-        raise ValueError(
-            f"Scaling factors must be compatible. Got {x._scaling_factor} and {y._scaling_factor}"
-        )
+    with tf.name_scope("match_moduli_and_scaling"):
+        # Mod switch to the smaller modulus of the two.
+        while x._num_mod_reductions < y._num_mod_reductions:
+            x = mod_reduce_tensor64(x)
+        while x._num_mod_reductions > y._num_mod_reductions:
+            y = mod_reduce_tensor64(y)
+
+        # First make sure the scaling factors are compatible.
+        frac = x._scaling_factor / y._scaling_factor
+        if abs(frac - int(frac)) != 0:
+            raise ValueError(
+                f"Scaling factors must be compatible. Got {x._scaling_factor} and {y._scaling_factor}"
+            )

-    # Match the scaling factors.
-    while x._scaling_factor > y._scaling_factor:
-        y = y.__mul__(y._scaling_factor)
-    while x._scaling_factor < y._scaling_factor:
-        x = x.__mul__(x._scaling_factor)
+        # Match the scaling factors.
+        while x._scaling_factor > y._scaling_factor:
+            y = y.__mul__(y._scaling_factor)
+        while x._scaling_factor < y._scaling_factor:
+            x = x.__mul__(x._scaling_factor)

    return x, y
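
The compatibility check in this hunk only requires the ratio of the two scaling factors to come out to a whole number; a tiny plain-Python illustration with made-up values:

    # Made-up scaling factors, just to show when the ValueError above would fire.
    for x_scale, y_scale in [(8, 2), (8, 3)]:
        frac = x_scale / y_scale
        ok = abs(frac - int(frac)) == 0
        print(f"{x_scale} vs {y_scale}: {'compatible' if ok else 'incompatible'}")
    # 8 vs 2: compatible
    # 8 vs 3: incompatible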

@@ -507,30 +508,32 @@ def _get_shell_dtype_from_underlying(type):


def _encode_scaling(tf_tensor, scaling_factor=1):
-    if tf_tensor.dtype in [tf.float32, tf.float64]:
-        return tf.cast(tf.round(tf_tensor * scaling_factor), tf.int64)
-    elif tf_tensor.dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
-        # Pass unsigned datatypes to shell as uint64.
-        return tf.cast(tf_tensor, tf.uint64)
-    elif tf_tensor.dtype in [tf.int8, tf.int16, tf.int32, tf.int64]:
-        # Pass signed datatypes to shell as int64.
-        return tf.cast(tf_tensor, tf.int64)
-    else:
-        raise ValueError(f"Unsupported dtype {tf_tensor.dtype}")
+    with tf.name_scope("encode_scaling"):
+        if tf_tensor.dtype in [tf.float32, tf.float64]:
+            return tf.cast(tf.round(tf_tensor * scaling_factor), tf.int64)
+        elif tf_tensor.dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
+            # Pass unsigned datatypes to shell as uint64.
+            return tf.cast(tf_tensor, tf.uint64)
+        elif tf_tensor.dtype in [tf.int8, tf.int16, tf.int32, tf.int64]:
+            # Pass signed datatypes to shell as int64.
+            return tf.cast(tf_tensor, tf.int64)
+        else:
+            raise ValueError(f"Unsupported dtype {tf_tensor.dtype}")


def _decode_scaling(scaled_tensor, output_dtype, scaling_factor):
-    if output_dtype in [tf.float32, tf.float64]:
-        assert scaled_tensor.dtype == tf.int64
-        return tf.cast(scaled_tensor, output_dtype) / scaling_factor
-    elif output_dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
-        assert scaled_tensor.dtype == tf.uint64
-        return tf.cast(scaled_tensor, output_dtype)
-    elif output_dtype in [tf.int8, tf.int16, tf.int32, tf.int64]:
-        assert scaled_tensor.dtype == tf.int64
-        return tf.cast(scaled_tensor, output_dtype)
-    else:
-        raise ValueError(f"Unsupported dtype {output_dtype}")
+    with tf.name_scope("decode_scaling"):
+        if output_dtype in [tf.float32, tf.float64]:
+            assert scaled_tensor.dtype == tf.int64
+            return tf.cast(scaled_tensor, output_dtype) / scaling_factor
+        elif output_dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
+            assert scaled_tensor.dtype == tf.uint64
+            return tf.cast(scaled_tensor, output_dtype)
+        elif output_dtype in [tf.int8, tf.int16, tf.int32, tf.int64]:
+            assert scaled_tensor.dtype == tf.int64
+            return tf.cast(scaled_tensor, output_dtype)
+        else:
+            raise ValueError(f"Unsupported dtype {output_dtype}")


@@ -554,18 +557,19 @@ def to_shell_plaintext(tensor, context):
    scaled_tensor = _encode_scaling(tensor, context.scaling_factor)

    # Pad the tensor to the correct number of slots.
-    first_dim = tf.cast(tf.shape(scaled_tensor)[0], dtype=tf.int64)
-    tf.Assert(
-        context.num_slots >= first_dim,
-        [f"First dimension must be <= {context.num_slots}. Got {first_dim}"],
-    )
-    padding = [[0, 0] for _ in range(len(scaled_tensor.shape))]
-    padding[0][1] = tf.cond(
-        context.num_slots > first_dim,
-        lambda: context.num_slots - first_dim,
-        lambda: tf.constant(0, dtype=tf.int64),
-    )
-    scaled_tensor = tf.pad(scaled_tensor, padding)
+    with tf.name_scope("pad_to_slots"):
+        first_dim = tf.cast(tf.shape(scaled_tensor)[0], dtype=tf.int64)
+        tf.Assert(
+            context.num_slots >= first_dim,
+            [f"First dimension must be <= {context.num_slots}. Got {first_dim}"],
+        )
+        padding = [[0, 0] for _ in range(len(scaled_tensor.shape))]
+        padding[0][1] = tf.cond(
+            context.num_slots > first_dim,
+            lambda: context.num_slots - first_dim,
+            lambda: tf.constant(0, dtype=tf.int64),
+        )
+        scaled_tensor = tf.pad(scaled_tensor, padding)

    return ShellTensor64(
        _raw_tensor=shell_ops.polynomial_import64(
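
The padding block in this hunk only grows the first (slot) dimension up to context.num_slots, filling the extra rows with zeros. A minimal standalone sketch with a hypothetical num_slots of 4, using static shapes for brevity where the library version handles dynamic shapes with tf.cond:

    import tensorflow as tf

    num_slots = 4  # hypothetical; in the library this comes from the shell context
    t = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int64)
    pad_rows = max(num_slots - t.shape[0], 0)
    padded = tf.pad(t, [[0, pad_rows], [0, 0]])  # zero-pad only the first dimension
    print(padded.shape)       # (4, 2)
    print(padded.numpy()[-1]) # [0 0]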
