Implement helper @as_jax_op to wrap JAX functions in PyTensor #537

Open
ricardoV94 opened this issue Dec 7, 2023 · 4 comments · May be fixed by #1120

Comments
@ricardoV94
Member

ricardoV94 commented Dec 7, 2023

Description

This blogpost walks through the logic for 3 different examples: https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/ and shows the logic is always the same (a rough sketch follows the list below):

  1. Wrap the jitted forward pass in an Op
  2. Wrap the jitted jvp (or vjp, I can never remember) as a GradOp to provide the gradient implementation
  3. Dispatch unjitted versions of the two Ops for integration with `function(..., mode="JAX")`

Things that cannot be obtained automatically (or maybe they can?) and should be opt-in as in @as_op:
4. Input and output types
5. infer_shape
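
A rough sketch of what such a decorator could look like is below. This is only an illustration, not a final API: it assumes a single float64 vector input and a scalar output (hardcoded precisely because points 4–5 are not inferred automatically) and uses the reverse-mode vjp for the gradient; step 3 (the JAX-mode dispatch) is left to the dispatch sketch further down the thread.

```python
# Rough sketch only, not a final API. Assumes a single float64 vector input and
# a scalar output; shapes/dtypes are hardcoded because input/output types and
# infer_shape (points 4-5 above) are not derived automatically.
import jax
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

# match PyTensor's float64 default so dtypes line up between the two libraries
jax.config.update("jax_enable_x64", True)


def as_jax_op(jax_fn):
    jitted_fn = jax.jit(jax_fn)
    # vjp of the scalar-output function: given x and the output cotangent g,
    # return the cotangent with respect to x
    jitted_vjp = jax.jit(lambda x, g: jax.vjp(jax_fn, x)[1](g)[0])

    class JAXVJPOp(Op):
        # Step 2: gradient Op wrapping the jitted vjp
        def make_node(self, x, g):
            x = pt.as_tensor_variable(x)
            g = pt.as_tensor_variable(g)
            return Apply(self, [x, g], [x.type()])

        def perform(self, node, inputs, output_storage):
            x, g = inputs
            output_storage[0][0] = np.asarray(
                jitted_vjp(x, g), dtype=node.outputs[0].dtype
            )

    vjp_op = JAXVJPOp()

    class JAXOp(Op):
        # Step 1: forward Op wrapping the jitted function
        def make_node(self, x):
            x = pt.as_tensor_variable(x)
            return Apply(self, [x], [pt.dscalar()])

        def perform(self, node, inputs, output_storage):
            (x,) = inputs
            output_storage[0][0] = np.asarray(
                jitted_fn(x), dtype=node.outputs[0].dtype
            )

        def grad(self, inputs, output_grads):
            (x,) = inputs
            (g,) = output_grads
            return [vjp_op(x, g)]

    # Step 3 (dispatching the unjitted function for mode="JAX") is omitted here
    return JAXOp()
```

Used roughly like this (hypothetical function name):

```python
@as_jax_op
def sum_of_squares(x):
    return (x ** 2).sum()

x = pt.dvector("x")
y = sum_of_squares(x)  # forward pass through the wrapped JAX function
dy_dx = pt.grad(y, x)  # gradient provided by the vjp Op
```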

ricardoV94 added the enhancement (New feature or request), jax backend compatibility, and feature request labels and removed the enhancement (New feature or request) label on Dec 7, 2023
ricardoV94 added the torch (PyTorch backend) label on Jul 10, 2024
@ricardoV94
Member Author

With the torch backend, an `as_torch_op` should also be feasible

@Dhruvanshu-Joshi
Member

https://colab.research.google.com/drive/1-GwLYGDIROVwk-PU0XrZdvy0hp7ddjB6?usp=sharing
Hi, can anyone give this a quick check? I was basically trying to solve issue #537 and came up with a solution along these lines for CumOp.
Also, the latter part where I do "get_cumop_class" is kind of hardcoded for now. I was not able to figure out how to determine the mode passed in "pytensor.function". Can you help me out with that?
If this is correct, I'll start extending this to the other ops and open a PR.

@ricardoV94
Member Author

@Dhruvanshu-Joshi looks like a great start. One thing is that axis/mode shouldn't be properties of the Op. In general the decorator will wrap a single Op without any extra parameters.

Regarding the function mode question, why do you need to know it? There shouldn't be a different Op for the different modes. The way backends are implemented is with dispatch. An Op must always have a perform method that runs in the default (Python) backend. Then, for special backends like JAX/NUMBA/PYTORCH, we dispatch a function on the same Op; we don't create a new Op. This is described here: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

You'll notice the blogpost also doesn't create distinct Ops for the different modes
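
To make that concrete, here is a minimal standalone sketch of the dispatch pattern with a toy Op (a hypothetical `MyExpOp`, not part of PyTensor): the Op only defines `perform` for the default Python backend, and a JAX implementation is registered for the same Op via `jax_funcify`, as in the linked docs.

```python
# Minimal, standalone sketch of the dispatch mechanism with a toy Op
# (hypothetical example, not part of PyTensor).
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch import jax_funcify


class MyExpOp(Op):
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        # default (Python/NumPy) backend implementation
        output_storage[0][0] = np.exp(inputs[0])


@jax_funcify.register(MyExpOp)
def jax_funcify_MyExpOp(op, **kwargs):
    # JAX backend: return a plain, unjitted JAX callable for the *same* Op;
    # the JAX linker jits the whole compiled graph once
    def my_exp(x):
        return jnp.exp(x)

    return my_exp
```

With this registration in place, `function(..., mode="JAX")` picks up the JAX implementation automatically, so the Op itself never needs to know which mode it is being compiled for.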

jdehning linked a pull request Dec 12, 2024 that will close this issue
@jdehning

I started a draft PR, and also have my first question. Where do I add this decorator (see PR)?
