Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gradient for vector repetitions #1192

Merged
merged 2 commits into from
Feb 11, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 5, 2025

Nobody asked for it, and nobody probably ever needed it (it exists since 2012), but it was a fun challenge.

Also cleaned up a bit the implementation by not allowing negative axis on the Op itself like we do in other cases (the helper repeat handles the conversion to positive axis)

doctest fails due to an unrelated print that should be solved by #1193


📚 Documentation preview 📚: https://pytensor--1192.org.readthedocs.build/en/1192/

@ricardoV94 ricardoV94 force-pushed the repeat_grad branch 4 times, most recently from fe54502 to 2ac8e9c Compare February 10, 2025 22:10
Copy link

codecov bot commented Feb 10, 2025

Codecov Report

Attention: Patch coverage is 92.30769% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.26%. Comparing base (4fa9bb8) to head (2543387).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/extra_ops.py 92.10% 1 Missing and 2 partials ⚠️

❌ Your patch status has failed because the patch coverage (92.30%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1192      +/-   ##
==========================================
- Coverage   82.27%   82.26%   -0.02%     
==========================================
  Files         186      186              
  Lines       48066    48069       +3     
  Branches     8633     8633              
==========================================
- Hits        39546    39543       -3     
- Misses       6360     6366       +6     
  Partials     2160     2160              
Files with missing lines Coverage Δ
pytensor/link/vm.py 92.00% <100.00%> (ø)
pytensor/tensor/extra_ops.py 88.06% <92.10%> (+0.05%) ⬆️

... and 1 file with indirect coverage changes

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs some small cleanup, otherwise cool!

pytensor/tensor/extra_ops.py Outdated Show resolved Hide resolved
pytensor/tensor/extra_ops.py Outdated Show resolved Hide resolved
pytensor/tensor/extra_ops.py Outdated Show resolved Hide resolved
It returns an array which has the same shape as `x`, except along the given
`axis`. The `axis` parameter is used to specify the axis along which values
are repeated. By default, a flattened version of `x` is used.
See `numpy.repeat` for more information.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if I'm working in pure pytensor? I have to import numpy just to read our docs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you go to the docs :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add links like JAX does, I think we already have cross-ref so it should work the same way: https://github.com/jax-ml/jax/blob/e7acb20ea2ba037a0e82134a6b1bcf35430d28ca/jax/_src/numpy/lax_numpy.py#L6980

Then the docs have a clickable link: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.repeat.html

Input data, tensor variable.
repeats
int, scalar or tensor variable
a: tensor_like
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
a: tensor_like
a: TensorLike

I think the pycharm linter checks these for valid types. There's no type called tensor_like, and it's also not fully human readable; kind of a weird middle ground

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we had a glossary like in PyMC: https://www.pymc.io/projects/docs/en/stable/glossary.html#term-tensor_like

In which case we should use tensor_like in the docs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytensor/tensor/extra_ops.py Show resolved Hide resolved
pytensor/tensor/extra_ops.py Show resolved Hide resolved
@ricardoV94 ricardoV94 merged commit ffdde1c into pymc-devs:main Feb 11, 2025
63 of 64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants