
Use expand_dims / squeeze in JAX implementation of Dimshuffle #847

Open
ricardoV94 opened this issue Jun 24, 2024 · 2 comments · May be fixed by #987

Comments

@ricardoV94
Member

ricardoV94 commented Jun 24, 2024

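```python
import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.elemwise import DimShuffle


@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op, **kwargs):
    def dimshuffle(x):
        # Reorder the input axes according to op.transposition
        res = jnp.transpose(x, op.transposition)
        # Keep the first len(op.shuffle) axes, then insert a length-1 axis
        # at each position listed in op.augment
        shape = list(res.shape[: len(op.shuffle)])
        for augm in op.augment:
            shape.insert(augm, 1)
        res = jnp.reshape(res, shape)
        if not op.inplace:
            res = jnp.copy(res)
        return res

    return dimshuffle
```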

The JAX docs for lax.reshape (which jnp.reshape uses) suggest that expand_dims / squeeze may be better for further optimizations: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reshape.html#jax.lax.reshape

Relevant part:

For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
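A minimal sketch of what the proposal could look like, written as a standalone JAX function rather than against PyTensor's DimShuffle op. The name dimshuffle_expand_squeeze and its new_order argument (input axis indices, with "x" marking a new length-1 axis) are hypothetical, chosen to mirror DimShuffle's semantics, and dropped axes are assumed to have length 1. Dropped axes are removed with lax.squeeze, the rest are reordered with jnp.transpose, and new axes are added with lax.expand_dims, so reshape is never called.

```python
import jax
import jax.numpy as jnp


def dimshuffle_expand_squeeze(x, new_order):
    """Hypothetical DimShuffle-like helper built on squeeze / expand_dims.

    new_order holds input axis indices and "x" for each new length-1 axis.
    Input axes absent from new_order are dropped and must have length 1.
    """
    x = jnp.asarray(x)
    ndim = x.ndim
    # Input axes that do not appear in new_order are squeezed away
    drop = [i for i in range(ndim) if i not in new_order]
    # Surviving axes keep their relative order after squeezing; map each
    # original axis index to its position in the squeezed array
    kept = [i for i in range(ndim) if i in new_order]
    position = {axis: pos for pos, axis in enumerate(kept)}
    x = jax.lax.squeeze(x, drop)
    # Reorder the surviving axes into the requested output order
    perm = [position[axis] for axis in new_order if axis != "x"]
    x = jnp.transpose(x, perm)
    # Insert length-1 axes at the output positions marked "x"
    augment = [pos for pos, axis in enumerate(new_order) if axis == "x"]
    return jax.lax.expand_dims(x, augment)


x = jnp.arange(6.0).reshape(2, 1, 3)
y = dimshuffle_expand_squeeze(x, (2, "x", 0))
print(y.shape)  # (3, 1, 2)
```

Inside PyTensor, the squeezed axes, permutation, and augmented positions would presumably be derived once from the op's existing transposition / shuffle / augment attributes instead of being recomputed on every call.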

@HarshvirSandhu
Contributor

I can work on this

@ricardoV94
Member Author

Thanks
