---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
Cell In[64], line 4
2 x_dummy = np.random.rand(100, 10)
3 y_dummy = np.random.randint(0, 2, size=(100,))
----> 4 model.fit(x_dummy, y_dummy, epochs=5, batch_size=10)
File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
[... skipping hidden 15 frame]
File /opt/conda/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:1720, in DynamicJaxprTracer._assert_live(self)
1718 def _assert_live(self) -> None:
1719 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1720 raise core.escaped_tracer_error(self, None)
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at /tmp/ipykernel_34/1644797871.py:17 traced for cond.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_34/1644797871.py:17 (<lambda>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_34/1281975580.py:4 (<module>)
/tmp/ipykernel_34/1644797871.py:15 (call)
/tmp/ipykernel_34/1644797871.py:17 (<lambda>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
You are getting this error because JAX transformations expect functions to explicitly return their outputs, and you are calculating the noise inside the lambda function, which leaks an intermediate value. To avoid the side effect, modify your call method so that the noise is not computed inside the lambda.
From a Python perspective, I think what is happening is that the lambda captures the values of its arguments when it is defined, so the seed generator is not called on every call. If you drop the lambda and use something like a class method instead, you should be fine. Alternatively, follow the suggestion above, which is a solid pattern.
The following code breaks.