Simplify dispatch of JAX random variables by handling rng split automatically #1204

Open
ricardoV94 opened this issue Feb 12, 2025 · 0 comments


Description

The innermost JAX dispatch for RandomVariables, `jax_sample_fn`, currently looks like this:

```python
import jax

import pytensor.tensor.random.basic as ptr

# Excerpt: `jax_sample_fn` is the singledispatch function that this
# snippet registers implementations on.


@jax_sample_fn.register(ptr.CauchyRV)
@jax_sample_fn.register(ptr.GumbelRV)
@jax_sample_fn.register(ptr.LaplaceRV)
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
def jax_sample_fn_loc_scale(op, node):
    """JAX implementation of random variables in the loc-scale families.

    JAX only implements the standard version of random variables in the
    loc-scale family. We thus need to translate and rescale the results
    manually.
    """
    name = op.name
    jax_op = getattr(jax.random, name)

    def sample_fn(rng, size, dtype, *parameters):
        # Split the stored key: one half is used for sampling,
        # the other is written back as the new RNG state.
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        loc, scale = parameters
        if size is None:
            size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
        sample = loc + jax_op(sampling_key, size, dtype) * scale
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
```

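The whole RNG logic (splitting `rng["jax_state"]` and writing the new key back) could instead be handled once in the outermost dispatch, `jax_funcify_RandomVariable`, which is where `jax_sample_fn` is called: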

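```python
# Excerpt from the body of jax_funcify_RandomVariable; `op`, `node`,
# `static_size` and `out_dtype` are defined earlier in that function.
if None in static_size:
    # Size is only known at runtime, so it is passed through.
    assert_size_argument_jax_compatible(node)

    def sample_fn(rng, size, *parameters):
        return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)

else:
    # Size is fully static, so it is given to JAX directly.
    def sample_fn(rng, size, *parameters):
        return jax_sample_fn(op, node=node)(
            rng, static_size, out_dtype, *parameters
        )

return sample_fn
```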
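A minimal, self-contained sketch of the proposal, using `functools.singledispatch` directly instead of PyTensor's dispatch machinery. `NormalRV`, `inner_sample_fn` and `outer_sample_fn` are hypothetical stand-ins for `ptr.NormalRV`, `jax_sample_fn` and `jax_funcify_RandomVariable`; the `{"jax_state": key}` layout and the in-place state update follow the current code in this issue.

```python
from functools import singledispatch

import jax
import jax.numpy as jnp


class NormalRV:
    """Hypothetical stand-in for a PyTensor RandomVariable Op."""

    name = "normal"  # name of the standard sampler in jax.random


@singledispatch
def inner_sample_fn(op):
    raise NotImplementedError(f"No JAX sampler for {type(op).__name__}")


@inner_sample_fn.register(NormalRV)
def _(op):
    jax_op = getattr(jax.random, op.name)

    # Receives a ready-to-use key and returns only the sample;
    # it never reads or writes the RNG state.
    def sample_fn(sampling_key, size, dtype, loc, scale):
        if size is None:
            size = jnp.broadcast_arrays(loc, scale)[0].shape
        return loc + jax_op(sampling_key, size, dtype) * scale

    return sample_fn


def outer_sample_fn(op, out_dtype=jnp.float32):
    inner = inner_sample_fn(op)

    def sample_fn(rng, size, *parameters):
        # Key management lives here, shared by every RV type:
        # one key goes to the inner sampler, the other becomes the new state.
        rng_key, sampling_key = jax.random.split(rng["jax_state"], 2)
        sample = inner(sampling_key, size, out_dtype, *parameters)
        rng["jax_state"] = rng_key
        return rng, sample

    return sample_fn


# Usage: draw three samples from Normal(loc=1, scale=2).
rng = {"jax_state": jax.random.PRNGKey(0)}
sample_fn = outer_sample_fn(NormalRV())
rng, draws = sample_fn(rng, (3,), 1.0, 2.0)
```

With this division of labor, each registered sampler only contains the sampling arithmetic, while the split and the state update are written once.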

If an implementation needs more than one sampling key, it can split the key it receives again.
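As an illustration, here is a hypothetical inner sampler that, as proposed, receives a single sampling key from the outer dispatch and returns only the sample. Because it needs two independent draws, it splits that key once more. The name `student_t_sample_fn` and the Student-t construction are assumptions for illustration only; `jax.random.t` already exists, so a real implementation would not need this.

```python
import jax
import jax.numpy as jnp


def student_t_sample_fn(sampling_key, size, dtype, df):
    """Hypothetical inner sampler: one key in, only the sample out."""
    if size is None:
        size = jnp.asarray(df).shape
    # Split the provided key again to get two independent keys.
    normal_key, gamma_key = jax.random.split(sampling_key, 2)
    z = jax.random.normal(normal_key, size, dtype)
    # A chi-squared(df) draw expressed as 2 * Gamma(df / 2, 1).
    chi2 = 2.0 * jax.random.gamma(gamma_key, df / 2.0, size, dtype)
    # Student-t: standard normal divided by sqrt(chi2 / df).
    return z / jnp.sqrt(chi2 / df)


samples = student_t_sample_fn(jax.random.PRNGKey(1), (1000,), jnp.float32, 5.0)
```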

ricardoV94 changed the title from "Simplify dispatch of JAX random variables by handling rng split automatically." to "Simplify dispatch of JAX random variables by handling rng split automatically" on Feb 12, 2025