Loss functions applied in alphabetical order instead of by dictionary keys in Keras 3.5.0 #20596

Closed
malcolmlett opened this issue Dec 5, 2024 · 5 comments
Labels: keras-team-review-pending (Pending review by a Keras team member), type:Bug

Comments

@malcolmlett

Environment info

  • Google Colab (CPU or GPU)
  • TensorFlow 2.17.0, 2.17.1
  • Python 3.10.12

Problem description
There seems to be a change in Keras 3.5.0 that has introduced a bug for models with multiple outputs.
The problem is not present in Keras 3.4.1.

Passing a dictionary as loss to model.compile() should result in those loss functions being applied to the respective outputs based on output name. But instead they now appear to be applied in alphabetical order of dictionary keys, leading to the wrong loss functions being applied against the model outputs.

For example, in the following snippet, "loss_small" gets applied against "output_big" when it should be applied against "output_small". It appears that the loss dictionary gets 1) re-ordered by alphabetical order of key, and then 2) the dictionary values are read off in the resultant order and applied as an ordered list against the model outputs.

...
output_small = Dense(1, activation="sigmoid", name="output_small")(x)
output_big = Dense(64, activation="softmax", name="output_big")(x)
model = Model(inputs=input_layer, outputs=[output_small, output_big])
model.compile(optimizer='adam',
              loss={
                  'output_small': DebugLoss(name='loss_small'),
                  'output_big': DebugLoss(name='loss_big')
              })

I reached this conclusion by flipping the order of the model outputs, the training labels, and the loss dictionary entries, then comparing the results, which is what the following code does.

Code to reproduce

import sys
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")
print(f"Python version: {sys.version}")
print()

print("Problem doesn't occur if model outputs happen to be ordered alphabetically: (big, small)")

# Generate synthetic training data
num_samples = 100
x_train = np.random.normal(size=(num_samples, 10))  # Input data
y_train_output_big = np.eye(64)[np.random.choice(64, size=num_samples)]  # Shape (num_samples, 64)
y_train_output_small = np.random.choice([0, 1], size=(num_samples, 1))  # Shape (num_samples, 1)

dataset = tf.data.Dataset.from_tensor_slices((x_train, (y_train_output_big, y_train_output_small)))
dataset = dataset.batch(num_samples)

# Define model with single input and two named outputs
input_layer = Input(shape=(10,))
x = Dense(64, activation="relu")(input_layer)
output_big = Dense(64, activation="softmax", name="output_big")(x)     # (100,64)
output_small = Dense(1, activation="sigmoid", name="output_small")(x)  # (100,1)
model = Model(inputs=input_layer, outputs=[output_big, output_small])

# Compile with custom loss function for debugging
class DebugLoss(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        print(f"{self.name} - y_true: {y_true.shape}, y_pred: {y_pred.shape}")
        return tf.reduce_mean((y_true - y_pred)**2)

model.compile(optimizer='adam',
              loss={
                  'output_big': DebugLoss(name='loss_big'),
                  'output_small': DebugLoss(name='loss_small')
              })

# Train 
tf.config.run_functions_eagerly(True)
history = model.fit(dataset, epochs=1, verbose=0)

print()
print("Problem occurs if model outputs happen to be ordered non-alphabetically: (small, big)")

# Generate synthetic training data
num_samples = 100
x_train = np.random.normal(size=(num_samples, 10))  # Input data
y_train_output_small = np.random.choice([0, 1], size=(num_samples, 1))  # Shape (num_samples, 1)
y_train_output_big = np.eye(64)[np.random.choice(64, size=num_samples)]  # Shape (num_samples, 64)

dataset = tf.data.Dataset.from_tensor_slices((x_train, (y_train_output_small, y_train_output_big)))
dataset = dataset.batch(num_samples)

# Define model with single input and two named outputs
input_layer = Input(shape=(10,))
x = Dense(64, activation="relu")(input_layer)
output_small = Dense(1, activation="sigmoid", name="output_small")(x) # (100,1)
output_big = Dense(64, activation="softmax", name="output_big")(x)    # (100,64)
model = Model(inputs=input_layer, outputs=[output_small, output_big])

# Compile with custom loss function for debugging
class DebugLoss(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        print(f"{self.name} - y_true: {y_true.shape}, y_pred: {y_pred.shape}")
        return tf.reduce_mean((y_true - y_pred)**2)

model.compile(optimizer='adam',
              loss={
                  'output_small': DebugLoss(name='loss_small'),
                  'output_big': DebugLoss(name='loss_big')
              })

# Train 
tf.config.run_functions_eagerly(True)
history = model.fit(dataset, epochs=1, verbose=0)

Code outputs on various environments
Current Google Colab env - fails on second ordering:

TensorFlow version: 2.17.1
Keras version: 3.5.0
Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]

Problem doesn't occur if model outputs happen to be ordered alphabetically: (big, small)
loss_big - y_true: (100, 64), y_pred: (100, 64)
loss_small - y_true: (100, 1), y_pred: (100, 1)

Problem occurs if model outputs happen to be ordered non-alphabetically: (small, big)
loss_big - y_true: (100, 1), y_pred: (100, 1)
loss_small - y_true: (100, 64), y_pred: (100, 64)

Downgraded TF version, no change:

TensorFlow version: 2.17.0
Keras version: 3.5.0
Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]

Problem doesn't occur if model outputs happen to be ordered alphabetically: (big, small)
loss_big - y_true: (100, 64), y_pred: (100, 64)
loss_small - y_true: (100, 1), y_pred: (100, 1)

Problem occurs if model outputs happen to be ordered non-alphabetically: (small, big)
loss_big - y_true: (100, 1), y_pred: (100, 1)
loss_small - y_true: (100, 64), y_pred: (100, 64)

Downgraded Keras, and now get correct output for both orderings:

TensorFlow version: 2.17.0
Keras version: 3.4.1
Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]

Problem doesn't occur if model outputs happen to be ordered alphabetically: (big, small)
loss_big - y_true: (100, 64), y_pred: (100, 64)
loss_small - y_true: (100, 1), y_pred: (100, 1)

Problem occurs if model outputs happen to be ordered non-alphabetically: (small, big)
loss_small - y_true: (100, 1), y_pred: (100, 1)
loss_big - y_true: (100, 64), y_pred: (100, 64)

Final remarks
This seems related to tensorflow/tensorflow#37887, but it looks like someone has since tried to fix that bug and perhaps introduced another one?
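
For anyone who wants to see the suspected mechanism without installing TensorFlow, here is a minimal plain-Python sketch. It is only a model of the hypothesis described above (sort the dictionary keys, then apply the values positionally); it is not taken from the Keras source, and the function names are made up for illustration.

```python
# Hypothetical model of the suspected behaviour; NOT Keras source code.

def pair_by_name(output_names, losses):
    """Expected behaviour: look up each output's loss by the output name."""
    return [(name, losses[name]) for name in output_names]


def pair_by_sorted_keys(output_names, losses):
    """Suspected behaviour: sort the dict keys alphabetically, then apply
    the resulting values positionally against the model outputs."""
    flattened = [losses[key] for key in sorted(losses)]
    return list(zip(output_names, flattened))


# Same ordering as the failing case in the report: (small, big).
output_names = ["output_small", "output_big"]
losses = {"output_small": "loss_small", "output_big": "loss_big"}

expected = pair_by_name(output_names, losses)
suspected = pair_by_sorted_keys(output_names, losses)

print("expected: ", expected)
print("suspected:", suspected)
```

Under this model, "loss_big" ends up paired with "output_small", which is consistent with the (100, 1) shapes printed for loss_big in the Keras 3.5.0 runs above.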

@malcolmlett
Author

A quick update:

  • The problem affects loss, but not metrics passed to model.compile().

I believe a reliable workaround is to pass the losses as a tuple, in the same order as the model outputs, instead of a dictionary. For example, this works for me:

model.compile(optimizer='adam',
              loss=(
                  DebugLoss(name='loss_small'),
                  DebugLoss(name='loss_big')
              ),
              metrics={
                  'output_small': DebugMetric(name='metric_small'),
                  'output_big': DebugMetric(name='metric_big')
              })
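
If you would rather keep writing the losses as a dictionary, a small helper can turn it into such a tuple. This is only a sketch: `losses_in_output_order` is a hypothetical name, and it assumes you can get the output names in model order (for example from `model.output_names`, which I have not verified on every Keras version).

```python
def losses_in_output_order(output_names, loss_by_name):
    """Return the losses as a tuple ordered like the model outputs.

    output_names: output layer names, in the order passed to Model(outputs=...).
    loss_by_name: dict mapping each output name to its loss.
    """
    missing = [name for name in output_names if name not in loss_by_name]
    if missing:
        raise KeyError(f"no loss given for outputs: {missing}")
    return tuple(loss_by_name[name] for name in output_names)


# Plain strings stand in for loss objects so this runs without TensorFlow.
ordered = losses_in_output_order(
    ["output_small", "output_big"],
    {"output_big": "loss_big", "output_small": "loss_small"},
)
print(ordered)
```

The result would then be passed as `loss=` to `model.compile()` in place of the dictionary.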

@sonali-kumari1
Contributor

Hi @malcolmlett,

Thanks for reporting this issue. I tried to reproduce it, but the output seems to be correct for both orderings in the latest versions of Keras (3.7.0) and TensorFlow (2.18.0).
Attaching gist for your reference.

@malcolmlett
Author

I expect this will depend on what the Google Colab contributors have to say, but Google Colab currently defaults to Keras 3.5.0, which has the bug. So I'm wondering if there's any chance of a backport or fix for Keras 3.5.0, but I'll leave that up to those who know better.

@VarunS1997
Collaborator

Hi, this was a bug with tree that was fixed in recent versions. Back-fixing would be a bit tough, and we aren't sure when Colab will update their image. But if you use some "Colab magic", you can force the update on your instance:

!pip install -U keras
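
To check whether a given environment falls into the versions discussed in this thread, a quick lookup like the one below may help. It is a sketch that only encodes what was reported here (3.4.1 and 3.7.0 gave correct output, 3.5.0 did not); `reported_status` is a hypothetical helper, and any other version is deliberately left as "unknown".

```python
# Versions and results as reported in this thread only.
REPORTED = {
    "3.4.1": "ok",        # correct output for both orderings
    "3.5.0": "affected",  # losses applied in alphabetical key order
    "3.7.0": "ok",        # could not reproduce
}


def reported_status(keras_version):
    """Return 'ok', 'affected', or 'unknown' for a Keras version string."""
    return REPORTED.get(keras_version, "unknown")


for version in ("3.4.1", "3.5.0", "3.6.0", "3.7.0"):
    print(version, reported_status(version))
```

In a notebook you could call `reported_status(keras.__version__)` after `import keras` to see which case applies.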
