Skip to content

Commit

Permalink
Fix dtype mismatch in embedding layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 29, 2024
1 parent 56f9cdf commit d68f374
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tf_shell_ml/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def build(self, input_shape):

def call(self, inputs):
if inputs.dtype != tf.int64:
raise ValueError(
f"Embedding layer expects int64 input. Got {inputs.dtype}."
)
# When using model.fit() keras will cast the input to float32.
inputs = tf.cast(inputs, tf.int64)

if inputs.ndim != 2:
raise ValueError(f"Embedding layer expects rank 2 input. Got {inputs}.")

Expand Down

0 comments on commit d68f374

Please sign in to comment.