Skip to content

Commit

Permalink
Fix the compatibility issues of Orthogonal and GRU (keras-team#19844)
Browse files Browse the repository at this point in the history

* Add legacy `Orthogonal` class name

* Add legacy `implementation` arg to `GRU`
Loading branch information
james77777778 authored Jun 12, 2024
1 parent 46df341 commit 4551644
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/src/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"uniform": RandomUniform,
"normal": RandomNormal,
"orthogonal": OrthogonalInitializer,
"Orthogonal": OrthogonalInitializer, # Legacy
"one": Ones,
"zero": Zeros,
}
Expand Down
4 changes: 4 additions & 0 deletions keras/src/initializers/random_initializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def test_orthogonal_initializer(self):

self.run_class_serialization_test(initializer)

# Test legacy class_name
initializer = initializers.get("Orthogonal")
self.assertIsInstance(initializer, initializers.OrthogonalInitializer)

def test_get_method(self):
obj = initializers.get("glorot_normal")
self.assertTrue(obj, initializers.GlorotNormal)
Expand Down
1 change: 1 addition & 0 deletions keras/src/layers/rnn/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def __init__(
trainable=kwargs.get("trainable", True),
name="gru_cell",
seed=seed,
implementation=kwargs.pop("implementation", 2),
)
super().__init__(
cell,
Expand Down
23 changes: 23 additions & 0 deletions keras/src/layers/rnn/gru_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,26 @@ def test_masking(self):
np.array([[0.11669192, 0.11669192], [0.28380975, 0.28380975]]),
output,
)

def test_legacy_implementation_argument(self):
    """Round-trip a GRU through its config with the legacy
    `implementation` key present and check the restored layer
    still produces the reference output.
    """
    # Deterministic input: 0..71 as float32, shaped (batch=3, steps=6, features=4).
    inputs = np.arange(72).reshape((3, 6, 4)).astype("float32")
    gru = layers.GRU(
        3,
        kernel_initializer=initializers.Constant(0.01),
        recurrent_initializer=initializers.Constant(0.02),
        bias_initializer=initializers.Constant(0.03),
    )
    # Inject the legacy argument into the serialized config, as an
    # old saved model would carry it.
    cfg = gru.get_config()
    cfg["implementation"] = 0  # Add legacy argument
    restored = layers.GRU.from_config(cfg)
    result = restored(inputs)
    expected = np.array(
        [
            [0.5217289, 0.5217289, 0.5217289],
            [0.6371659, 0.6371659, 0.6371659],
            [0.39384964, 0.39384964, 0.3938496],
        ]
    )
    self.assertAllClose(expected, result)

0 comments on commit 4551644

Please sign in to comment.