Skip to content

Commit

Permalink
Fix issue with Masking layer with Tensor as mask_value (#20791)
Browse files Browse the repository at this point in the history
* Fix issue with Masking layer with Tensor as `mask_value`

* fix formatting
  • Loading branch information
Surya2k1 authored Jan 23, 2025
1 parent 90568da commit 592c118
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/src/layers/core/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer
from keras.src.saving.serialization_lib import deserialize_keras_object


@keras_export("keras.layers.Masking")
Expand Down Expand Up @@ -45,6 +46,9 @@ class Masking(Layer):

def __init__(self, mask_value=0.0, **kwargs):
super().__init__(**kwargs)
# `mask_value` can be a serialized tensor, hence verify it
if isinstance(mask_value, dict) and mask_value.get("config", None):
mask_value = deserialize_keras_object(mask_value)
self.mask_value = mask_value
self.supports_masking = True
self.built = True
Expand Down
21 changes: 21 additions & 0 deletions keras/src/layers/core/masking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from keras.src import layers
from keras.src import models
from keras.src import ops
from keras.src import testing
from keras.src.saving import load_model


class MaskingTest(testing.TestCase):
Expand Down Expand Up @@ -57,3 +59,22 @@ def call(self, inputs, mask=None):
]
)
model(x)

@pytest.mark.requires_trainable_backend
def test_masking_with_tensor(self):
model = models.Sequential(
[
layers.Masking(mask_value=ops.convert_to_tensor([0.0])),
layers.LSTM(1),
]
)
x = np.array(
[
[[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],
[[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]],
]
)
model(x)
model.save("model.keras")
reload_model = load_model("model.keras")
reload_model(x)

0 comments on commit 592c118

Please sign in to comment.