From d68f374420b0929299b9adde8cadd0ed2f867752 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 29 Oct 2024 22:17:36 +0000 Subject: [PATCH] Fix dtype mismatch in embedding layer. --- tf_shell_ml/embedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_shell_ml/embedding.py b/tf_shell_ml/embedding.py index 31cb1db..93fc89d 100644 --- a/tf_shell_ml/embedding.py +++ b/tf_shell_ml/embedding.py @@ -63,9 +63,9 @@ def build(self, input_shape): def call(self, inputs): if inputs.dtype != tf.int64: - raise ValueError( - f"Embedding layer expects int64 input. Got {inputs.dtype}." - ) + # When using model.fit() keras will cast the input to float32. + inputs = tf.cast(inputs, tf.int64) + if inputs.ndim != 2: raise ValueError(f"Embedding layer expects rank 2 input. Got {inputs}.")