ricardoV94 changed the title to "Simplify dispatch of JAX random variables by handling rng split automatically" on Feb 12, 2025
Description
The JAX inner-most dispatch for RandomVariables, `jax_sample_fn`, currently handles the rng logic inside each implementation; see pytensor/pytensor/link/jax/dispatch/random.py, lines 146 to 172 in 964cccb.
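For readers without the linked lines at hand, here is a minimal, hypothetical sketch of the pattern being described: each inner sample function receives the rng key, splits it itself, and returns the advanced key alongside the draw. The name `normal_sample_fn` and the `(rng_key, size, loc, scale)` signature are illustrative assumptions, not the actual PyTensor code.

```python
import jax
import jax.numpy as jnp


# Hypothetical inner sample function in the "each implementation splits" style.
def normal_sample_fn(rng_key, size, loc, scale):
    # Boilerplate repeated in every implementation: split, draw, return new key.
    rng_key, sampling_key = jax.random.split(rng_key, 2)
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    sample = loc + scale * jax.random.normal(sampling_key, shape=size)
    return rng_key, sample


key = jax.random.PRNGKey(0)
next_key, draws = normal_sample_fn(key, (3,), 0.0, 1.0)
```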
The whole rng logic could instead be handled once, in the outermost dispatch `jax_funcify_RandomVariable` (pytensor/pytensor/link/jax/dispatch/random.py, lines 104 to 117 in 964cccb).
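A minimal sketch of the proposal, under the same assumptions: a hypothetical outer wrapper, standing in for `jax_funcify_RandomVariable`, performs the single two-way split and passes only the sampling key to the inner function, which no longer touches the rng state. `wrap_sample_fn` and `normal_sample_fn` are illustrative names, not PyTensor's API.

```python
import jax
import jax.numpy as jnp


def wrap_sample_fn(inner_sample_fn):
    """Hypothetical outer dispatch: handles the rng split once for all RVs."""
    def sample_fn(rng_key, size, *parameters):
        next_key, sampling_key = jax.random.split(rng_key, 2)
        sample = inner_sample_fn(sampling_key, size, *parameters)
        # Return the advanced key so the next call gets fresh randomness.
        return next_key, sample
    return sample_fn


# The inner implementation only draws; it never splits or returns a key.
def normal_sample_fn(sampling_key, size, loc, scale):
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    return loc + scale * jax.random.normal(sampling_key, shape=size)


normal = wrap_sample_fn(normal_sample_fn)
key = jax.random.PRNGKey(0)
next_key, draws = normal(key, (3,), 0.0, 1.0)
```

With this layout, adding a distribution means writing only the draw; the key bookkeeping lives in one place.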
If an implementation needs a split other than into two keys, it can simply split the provided rng again.
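As a hypothetical illustration of that case: a two-component normal mixture needs two independent keys (one to pick the component, one for the normal draw), so the inner function splits the single sampling key it receives. The wrapper repeats the assumed single two-way split; all names and signatures here are illustrative, not PyTensor code.

```python
import jax
import jax.numpy as jnp


def wrap_sample_fn(inner_sample_fn):
    """Hypothetical outer dispatch that always performs a single 2-way split."""
    def sample_fn(rng_key, size, *parameters):
        next_key, sampling_key = jax.random.split(rng_key, 2)
        return next_key, inner_sample_fn(sampling_key, size, *parameters)
    return sample_fn


# Hypothetical RV needing two independent keys: a two-component normal mixture.
def mixture_sample_fn(sampling_key, size, p, loc0, loc1):
    # Split the provided key again inside the implementation.
    choice_key, normal_key = jax.random.split(sampling_key, 2)
    pick_second = jax.random.bernoulli(choice_key, p, shape=size)
    loc = jnp.where(pick_second, loc1, loc0)
    return loc + jax.random.normal(normal_key, shape=size)


mixture = wrap_sample_fn(mixture_sample_fn)
key = jax.random.PRNGKey(0)
next_key, draws = mixture(key, (1000,), 0.5, -5.0, 5.0)
```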