-
Notifications
You must be signed in to change notification settings - Fork 269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
restore test_apply_paddings_check runtime_checks test #771
base: main
Are you sure you want to change the base?
Conversation
The main idea is that we need to call `jax.effects_barrier()`, because the error may be raised in an XLA computation that is asynchronous with the main Python thread and therefore we need to block. (There may have been a recent change in behavior, where JAX runs more computations asynchronously on the CPU backend.) We could put that call to `jax.effects_barrier()` in the test code (and corresponding user code), or we could bulid it into the `runtime_checks` context manager. Currently this commit does the latter. I also tweaked the `runtime_checks` logic to use a `try/finally` pattern to restore the state when the context is exited, even when it's exited via exception. We may want to do the same to context managers like `numeric_checks`. While the test now passes, there is a gross warning printed about "Exception ignored in atexit callback". That may be a JAX internal bug, or it may be some quirk of CPython 3.10; I haven't investigated further. Let me know if that seems like a problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mattjj! While the warning isn't pleasant, I think we can live with it for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
My intrepid teammates @yashk2810 and @hawkinsp noticed that in the most recent release of JAX we no longer raise |
Thank you! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
The main idea is that we need to call
jax.effects_barrier()
, because the error may be raised in an XLA computation that is asynchronous with the main Python thread and therefore we need to block. (There may have been a recent change in behavior, where JAX runs more computations asynchronously on the CPU backend.) We could put that call tojax.effects_barrier()
in the test code (and corresponding user code), or we could bulid it into theruntime_checks
context manager. Currently this commit does the latter.I also tweaked the
runtime_checks
logic to use atry/finally
pattern to restore the state when the context is exited, even when it's exited via exception. We may want to do the same to context managers likenumeric_checks
.While the test now passes, there is a gross warning printed about "Exception ignored in atexit callback". That may be a JAX internal bug, or it may be some quirk of CPython 3.10; I haven't investigated further. Let me know if that seems like a problem.
What do you think?