UnexpectedTracerError: JAX had a side effect - DynamicJaxprTracer - set JAX_CHECK_TRACER_LEAKS #20724

Closed
innat opened this issue Jan 4, 2025 · 3 comments

@innat

innat commented Jan 4, 2025

The following code fails with an `UnexpectedTracerError` when training with the JAX backend.

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers

class RNL(layers.Layer):
    def __init__(self, noise_rate, **kwargs):
        super().__init__(**kwargs)
        self.noise_rate = noise_rate
        self.seed_generator = keras.random.SeedGenerator(seed=1337)

    def call(self, inputs):
        apply_noise = keras.random.uniform([], seed=self.seed_generator) < self.noise_rate
        outputs = keras.ops.cond(
            pred=apply_noise,
            true_fn=lambda: inputs + keras.random.uniform(
                shape=keras.ops.shape(inputs),
                minval=0,
                maxval=self.noise_rate,
                seed=self.seed_generator
            ),
            false_fn=lambda: inputs,
        )
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape

import numpy as np
from keras import layers, models

def create_dummy_model(noise_rate=0.1):
    model = models.Sequential([
        layers.Input(shape=(10,)),
        RNL(noise_rate=noise_rate),
        layers.Dense(32, activation="relu"),
        layers.Dense(1, activation="sigmoid")
    ])
    return model
    
model = create_dummy_model(noise_rate=0.2)
model.compile(
    optimizer="adam", 
    loss="binary_crossentropy", 
    metrics=["accuracy"]
)
x_dummy = np.random.rand(100, 10)
y_dummy = np.random.randint(0, 2, size=(100,))
model.fit(x_dummy, y_dummy, epochs=5, batch_size=10)
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
Cell In[64], line 4
      2 x_dummy = np.random.rand(100, 10)
      3 y_dummy = np.random.randint(0, 2, size=(100,))
----> 4 model.fit(x_dummy, y_dummy, epochs=5, batch_size=10)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

    [... skipping hidden 15 frame]

File /opt/conda/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:1720, in DynamicJaxprTracer._assert_live(self)
   1718 def _assert_live(self) -> None:
   1719   if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1720     raise core.escaped_tracer_error(self, None)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at /tmp/ipykernel_34/1644797871.py:17 traced for cond.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_34/1644797871.py:17 (<lambda>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_34/1281975580.py:4 (<module>)
/tmp/ipykernel_34/1644797871.py:15 (call)
/tmp/ipykernel_34/1644797871.py:17 (<lambda>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
@sonali-kumari1

Hi @innat -

The error occurs because JAX transformations expect functions to explicitly return their outputs. Drawing the noise inside the `true_fn` lambda advances `self.seed_generator` there, so its updated state (the `uint32[2]` value in the traceback) is created inside the `cond` trace and leaks out of it. To avoid this side effect, you can compute the noise before calling `cond` by modifying your `call` method like this:

    def call(self, inputs):
        # Draw the noise before cond, so the seed generator is not
        # advanced inside the traced branch.
        noise = keras.random.uniform(
            shape=keras.ops.shape(inputs),
            minval=0,
            maxval=self.noise_rate,
            seed=self.seed_generator
        )
        apply_noise = keras.random.uniform([], seed=self.seed_generator) < self.noise_rate
        outputs = keras.ops.cond(
            pred=apply_noise,
            true_fn=lambda: inputs + noise,
            false_fn=lambda: inputs,
        )
        return outputs

Attaching a gist for your reference. Thanks!
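The same pattern as a minimal pure-JAX sketch, assuming explicit PRNG keys in place of a Keras `SeedGenerator` (the function name `random_noise` is hypothetical). All random values are drawn before any branching, and `jnp.where` then selects between the noisy and clean inputs. Replacing `cond` with a select is a substitution of mine, not part of the suggestion above; it computes both alternatives, which is cheap for an elementwise add.

```python
import jax
import jax.numpy as jnp

def random_noise(x, key, noise_rate):
    """Pure-JAX analogue of the RNL layer: randomness is drawn before any branching."""
    key_apply, key_noise = jax.random.split(key)
    # Scalar coin flip: add noise with probability `noise_rate`.
    apply_noise = jax.random.uniform(key_apply, ()) < noise_rate
    # Draw the noise unconditionally, outside any control flow.
    noise = jax.random.uniform(key_noise, x.shape, minval=0.0, maxval=noise_rate)
    # A scalar predicate broadcasts, so this is all-or-nothing, like cond.
    return jnp.where(apply_noise, x + noise, x)

# Works unchanged under jit because nothing is written outside the function.
out = jax.jit(random_noise)(jnp.ones((4, 10)), jax.random.PRNGKey(1337), 0.2)
```

In Keras code the analogous select would be `keras.ops.where(apply_noise, inputs + noise, inputs)`.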

@fchollet
Collaborator

fchollet commented Jan 7, 2025

From a Python perspective, I think what's going on is that the lambda definition captures its arguments' values, so the seed generator won't be called on every call. If you drop the lambda and use something like a class method instead, you should be good. Alternatively, follow the suggestion above; that's a solid pattern.
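For comparison, a pure-JAX sketch in which the random draw does stay inside a `lax.cond` branch. A named function passed as a branch is traced by JAX just like a lambda; what keeps this version free of side effects is that the PRNG key is passed in as an explicit operand and the branch only returns its result, rather than updating generator state held outside the branch. All function names here are hypothetical, and explicit keys are an assumption standing in for `SeedGenerator`.

```python
import jax
import jax.numpy as jnp

def add_noise_branch(operands):
    x, key, noise_rate = operands
    # The key is an explicit input, so nothing outside the branch is mutated.
    return x + jax.random.uniform(key, x.shape, minval=0.0, maxval=noise_rate)

def keep_branch(operands):
    x, _, _ = operands
    return x

@jax.jit
def random_noise_cond(x, key, noise_rate):
    key_apply, key_noise = jax.random.split(key)
    apply_noise = jax.random.uniform(key_apply, ()) < noise_rate
    # Both branches take the same operands and return arrays of the same shape/dtype.
    return jax.lax.cond(apply_noise, add_noise_branch, keep_branch,
                        (x, key_noise, noise_rate))

out = random_noise_cond(jnp.ones((4, 10)), jax.random.PRNGKey(1337), 0.2)
```

The trade-off against the `where` version is that only the taken branch runs at execution time, at the cost of threading the key through the operands explicitly.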

4 participants