-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement gradient for vector repetitions #1192
Conversation
fe54502
to
2ac8e9c
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (92.30%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1192 +/- ##
==========================================
- Coverage 82.27% 82.26% -0.02%
==========================================
Files 186 186
Lines 48066 48069 +3
Branches 8633 8633
==========================================
- Hits 39546 39543 -3
- Misses 6360 6366 +6
Partials 2160 2160
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs some small cleanup, otherwise cool!
pytensor/tensor/extra_ops.py
Outdated
It returns an array which has the same shape as `x`, except along the given | ||
`axis`. The `axis` parameter is used to specify the axis along which values | ||
are repeated. By default, a flattened version of `x` is used. | ||
See `numpy.repeat` for more information. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if I'm working in pure pytensor? I have to import numpy just to read our docs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you go to the docs :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could add links like JAX does, I think we already have cross-ref so it should work the same way: https://github.com/jax-ml/jax/blob/e7acb20ea2ba037a0e82134a6b1bcf35430d28ca/jax/_src/numpy/lax_numpy.py#L6980
Then the docs have a clickable link: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.repeat.html
Input data, tensor variable. | ||
repeats | ||
int, scalar or tensor variable | ||
a: tensor_like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a: tensor_like | |
a: TensorLike |
I think the pycharm linter checks these for valid types. There's no type called tensor_like
, and it's also not fully human readable; kind of a weird middle ground
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we had a glossary like in PyMC: https://www.pymc.io/projects/docs/en/stable/glossary.html#term-tensor_like
In which case we should use tensor_like
in the docs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CC @OriolAbril
Also cleans up implementation and documentation
b3280d9
to
2543387
Compare
Nobody asked for it, and nobody probably ever needed it (it exists since 2012), but it was a fun challenge.
Also cleaned up a bit the implementation by not allowing negative axis on the Op itself like we do in other cases (the helper repeat handles the conversion to positive axis)
doctest fails due to an unrelated print that should be solved by #1193
📚 Documentation preview 📚: https://pytensor--1192.org.readthedocs.build/en/1192/