https://colab.research.google.com/drive/1-GwLYGDIROVwk-PU0XrZdvy0hp7ddjB6?usp=sharing
Hi, can anyone give this a quick check? I was trying to solve issue #537 and came up with a solution along these lines for CumOp.
Also, the latter part where I do "get_cumop_class" is kind of hardcoded for now. I wasn't able to figure out how to determine the mode passed to "pytensor.function". Can you help me with that?
If this is correct, I'll start extending this to the other ops and open a PR.
@Dhruvanshu-Joshi looks like a great start. One thing is that axis/mode shouldn't be properties of the Op. In general the decorator will wrap a single Op without any extra parameters.
Regarding the function mode question, why do you need to know it? There shouldn't be a different Op for each mode. Backends are implemented with dispatch: an Op must always have a perform method that runs in the default (Python) backend. Then, for special backends like JAX/NUMBA/PYTORCH, we dispatch a function on the same Op instead of creating a new Op. This is described here: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
You'll notice the blogpost also doesn't create distinct Ops for the different modes
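The dispatch pattern described above can be sketched with the standard library alone. In PyTensor, the JAX backend registers implementations with `jax_funcify.register(SomeOp)`; in this sketch a hypothetical `backend_funcify`, built on `functools.singledispatch`, plays that role. `Op` and `CumSum` are simplified stand-ins: the `perform(node, inputs, output_storage)` signature mirrors PyTensor's, but nothing else here is PyTensor's actual API.

```python
from functools import singledispatch
from itertools import accumulate


class Op:
    """Stand-in for a PyTensor Op base class (hypothetical, for illustration)."""

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError


class CumSum(Op):
    """Hypothetical cumulative-sum Op over a 1-D sequence."""

    def perform(self, node, inputs, output_storage):
        # Default (Python) backend: every Op must implement perform.
        (x,) = inputs
        output_storage[0][0] = list(accumulate(x))


@singledispatch
def backend_funcify(op, **kwargs):
    # Generic fallback: no backend implementation registered for this Op type.
    raise NotImplementedError(f"No backend implementation for {type(op).__name__}")


@backend_funcify.register(CumSum)
def _backend_funcify_cumsum(op, **kwargs):
    # Registered on the *same* Op class; no new Op is created for the backend.
    # A real JAX backend would return something like jax.numpy.cumsum here.
    def cumsum(x):
        out, total = [], 0
        for value in x:
            total += value
            out.append(total)
        return out

    return cumsum


op = CumSum()
storage = [[None]]
op.perform(None, [[1, 2, 3]], storage)
print(storage[0][0])                    # [1, 3, 6]
print(backend_funcify(op)([1, 2, 3]))   # [1, 3, 6]
```

The point of the design is that the mode never has to be inspected by the Op: the compiled function's backend looks up the implementation by Op type, and falls back to an error (or to `perform`) when nothing is registered.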
Description
This blogpost walks through the logic for 3 different examples: https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/ and shows the logic is always the same:
Things that cannot be obtained automatically (or maybe they can?) and should be opt-in, as in @as_op:
4. Input and output types
5. infer_shape
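A minimal sketch of what "opt-in" could look like, assuming a decorator shaped like `@as_op`, which takes `itypes`, `otypes` and an optional `infer_shape`. The decorator name `wrap_as_op`, the use of plain Python classes as "types", and the `infer_shape(fgraph, node, input_shapes)` callback signature are hypothetical stand-ins, not PyTensor's implementation.

```python
from itertools import accumulate


def wrap_as_op(itypes, otypes, infer_shape=None):
    """Hypothetical decorator: wrap a function in an Op-like object.

    Input/output types and the shape function are opt-in arguments, as in
    @as_op, because they cannot easily be inferred from an arbitrary function.
    """

    def decorator(fn):
        class WrappedOp:
            def __init__(self):
                self.itypes = list(itypes)
                self.otypes = list(otypes)

            def perform(self, node, inputs, output_storage):
                # Check inputs against the declared input types (stand-in check).
                for value, expected in zip(inputs, self.itypes):
                    if not isinstance(value, expected):
                        raise TypeError(
                            f"expected {expected.__name__}, got {type(value).__name__}"
                        )
                output_storage[0][0] = fn(*inputs)

            def infer_shape(self, fgraph, node, input_shapes):
                # Only available if the user opted in by passing infer_shape.
                if infer_shape is None:
                    raise NotImplementedError("no infer_shape was provided")
                return infer_shape(fgraph, node, input_shapes)

        WrappedOp.__name__ = fn.__name__
        return WrappedOp()

    return decorator


# The output has the same shape as the input, so infer_shape echoes it back.
@wrap_as_op(itypes=[list], otypes=[list], infer_shape=lambda fgraph, node, shapes: [shapes[0]])
def cumsum(x):
    return list(accumulate(x))


storage = [[None]]
cumsum.perform(None, [[1, 2, 3]], storage)
print(storage[0][0])                            # [1, 3, 6]
print(cumsum.infer_shape(None, None, [(3,)]))   # [(3,)]
```

Everything that can be derived from the wrapped function (calling it in `perform`) happens automatically; only the types and the shape rule are supplied by the user.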