Skip to content

Commit

Permalink
Estimate actual number of weights, variables, seed generators if expe…
Browse files Browse the repository at this point in the history
…cted ones not set. (keras-team#20010)

When using layers uses composition it should build each sublayer manually.
  • Loading branch information
shkarupa-alex authored Aug 15, 2024
1 parent d72a0ea commit 97714e3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
1 change: 0 additions & 1 deletion keras/src/layers/activations/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(self, axis=-1, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.built = True

def call(self, inputs, mask=None):
if mask is not None:
Expand Down
27 changes: 27 additions & 0 deletions keras/src/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,33 @@ def run_layer_test(
lambda _: "float32", input_shape
)

# Estimate actual number of weights, variables, seed generators if
# expected ones not set. When using layers uses composition it should
# build each sublayer manually.
if input_data is not None or input_shape is not None:
if input_data is None:
input_data = create_eager_tensors(
input_shape, input_dtype, input_sparse
)
layer = layer_cls(**init_kwargs)
if isinstance(input_data, dict):
layer(**input_data, **call_kwargs)
else:
layer(input_data, **call_kwargs)

if expected_num_trainable_weights is None:
expected_num_trainable_weights = len(layer.trainable_weights)
if expected_num_non_trainable_weights is None:
expected_num_non_trainable_weights = len(
layer.non_trainable_weights
)
if expected_num_non_trainable_variables is None:
expected_num_non_trainable_variables = len(
layer.non_trainable_variables
)
if expected_num_seed_generators is None:
expected_num_seed_generators = len(get_seed_generators(layer))

# Serialization test.
layer = layer_cls(**init_kwargs)
self.run_class_serialization_test(layer, custom_objects)
Expand Down

0 comments on commit 97714e3

Please sign in to comment.