Make preprocessing and postprocessing consistent across transforms #93
Conversation
I did not add n-dimensional separable transforms on purpose, because I was thinking people would ask for these in all other cases too, where they are trickier to deliver.
src/ptwt/_util.py (outdated)

```python
    raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
else:
    undo_swap_fn = partial(_undo_swap_axes, axes=axes)
    coeffs = _map_result(coeffs, undo_swap_fn)
```
I would really advise against all of these `_map_result` calls: have one function that does the processing and can be reused, then just do list comprehensions for all successive function calls.

```python
coeffs = _map_result(coeffs, lambda x: x.squeeze(0))
```

will always be less readable and understandable than

```python
coeffs = _map_result(coeffs)
coeffs = [coeff.squeeze(0) for coeff in coeffs]
```

when `_map_result` has lots of hidden functionality.
The snippet `coeffs = _map_result(coeffs, lambda x: x.squeeze(0))` applies the function `x.squeeze(0)` to all tensors in `coeffs`. In the 1d case (where `coeffs` is of type `list[Tensor]`) this is equivalent to the list comprehension, as you said. However, `coeffs` might also be

```python
(Tensor, dict[str, Tensor], ...)
(Tensor, (Tensor, Tensor, Tensor), ...)
```

So using `_map_result` allows us to write the function once for all possible coefficient types. Would it perhaps help to rename `_map_result` or add documentation for it?
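For illustration, here is a minimal sketch of the tree-map idea behind such a helper. The name `map_over_tensors` and the type alias are made up for this example and do not reflect the actual `_map_result` implementation in ptwt:

```python
from typing import Callable, Union

import torch

# Hypothetical names for illustration only; ptwt's real helper may differ.
CoeffTree = Union[torch.Tensor, list, tuple, dict]


def map_over_tensors(
    coeffs: CoeffTree, fn: Callable[[torch.Tensor], torch.Tensor]
) -> CoeffTree:
    """Apply fn to every tensor leaf of a nested coefficient structure."""
    if isinstance(coeffs, torch.Tensor):
        return fn(coeffs)
    if isinstance(coeffs, dict):
        return {key: map_over_tensors(val, fn) for key, val in coeffs.items()}
    if isinstance(coeffs, (list, tuple)):
        return type(coeffs)(map_over_tensors(val, fn) for val in coeffs)
    raise TypeError(f"Unsupported coefficient type: {type(coeffs)}")


# Works for the 1d case (list[Tensor]) ...
coeffs_1d = [torch.ones(1, 8), torch.ones(1, 4)]
squeezed_1d = map_over_tensors(coeffs_1d, lambda x: x.squeeze(0))

# ... and just as well for the nested 2d layout (Tensor, dict[str, Tensor], ...).
coeffs_2d = [
    torch.ones(1, 4, 4),
    {"ad": torch.ones(1, 4, 4), "da": torch.ones(1, 4, 4), "dd": torch.ones(1, 4, 4)},
]
squeezed_2d = map_over_tensors(coeffs_2d, lambda x: x.squeeze(0))
```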
I think the name should have something to do with tree and map. The new `_apply_to_tensor_elems` name really hides how general the concept is. I would take a page from https://jax.readthedocs.io/en/latest/_autosummary/jax.tree.map.html#jax.tree.map and also use their type hinting. The concept does not exist in torch, but I think it makes sense here, since we save on a lot of boilerplate code. Perhaps we should include the link and explain what's going on?
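For reference, this is roughly how the linked `jax.tree.map` operates on an arbitrarily nested container (a short illustration, assuming a recent JAX version where the `jax.tree` namespace exists; older versions expose the same behaviour as `jax.tree_util.tree_map`):

```python
import jax

# A nested "pytree" of tuples, dicts, and lists; the map is applied to leaves.
coeffs = (1.0, {"ad": 2.0, "da": 3.0}, [4.0, 5.0])
doubled = jax.tree.map(lambda leaf: leaf * 2, coeffs)
print(doubled)  # (2.0, {'ad': 4.0, 'da': 6.0}, [8.0, 10.0])
```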
Here is an interesting intro discussing the pytree processing philosophy: https://jax.readthedocs.io/en/latest/working-with-pytrees.html
I also think @cthoyt has a point since the tree-map concept is not very popular.
I agree that when comparing

```python
coeffs = _map_result(coeffs, partial(torch.squeeze, 0))
```

to

```python
coeffs = [coeff.squeeze(0) for coeff in coeffs]
```

the list comprehension wins, but what if it's a nested structure?
@v0lta I made the n-dim transform private. Does that work?
Yes, that works. However, we won't be able to support n-dimensional transforms across the board because PyTorch does not provide the interfaces we would need to do that. Padding, for example, works only up to 3D (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html). We have the same problem with isotropic convolution. So I think we should communicate that nd-transforms are out of the scope of this project.
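To make the limitation concrete, the sketch below shows what happens with non-constant padding modes beyond three spatial dimensions (only `mode="constant"` handles arbitrary dimensions; the exact exception type and message may vary between PyTorch versions):

```python
import torch
import torch.nn.functional as F

x3 = torch.randn(1, 1, 8, 8, 8)  # batch, channel, three spatial dims
y = F.pad(x3, (2, 2, 2, 2, 2, 2), mode="reflect")  # 3D reflect padding works

x4 = torch.randn(1, 1, 8, 8, 8, 8)  # four spatial dims
try:
    F.pad(x4, (2, 2) * 4, mode="reflect")  # no reflect padding beyond 3 dims
except (NotImplementedError, RuntimeError) as err:
    print(f"not supported: {err}")
```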
In general, I am a big fan of this pull request! Thanks @felixblanke. I am going to clean up the docs for
I think so far we only refer to the Packet data structure as a tree. Maybe we can add a link to the JAX discussion as a reference?
I am not sure if users need to know. I think this is more for us here internally. Unlike JAX's tree map, ours is coefficient-specific, hence the proposed `_apply_to_tensor_elems` name.
Actually, never mind. The user argument does not matter since it's a private function. If you think it helps to understand the idea, please add a link. I think it might help potential future contributors.
Okay, I think we are ready to merge. @felixblanke @cthoyt is everyone on board?
Okay, let's merge!
This addresses #92.

For all discrete transforms, the preprocessing and postprocessing of coefficients and tensors is very similar (i.e. folding and swapping of axes, adding batch dims, etc.). This PR moves this functionality into shared functions that use `_map_result`.

Also, the check for consistent devices and dtypes between coefficient tensors is moved into a function in `_utils`.

Last, as it was possible to add it with a few lines of code, I added the n-dimensional fully separable transforms (`fswavedecn`, `fswaverecn`). If this is not wanted, I can revert their addition.

Further, I did some minor refactorings along the way.
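As context for readers unfamiliar with the fully separable transforms, PyWavelets already provides CPU versions under the same names. The sketch below uses the pywt API rather than the ptwt functions added in this PR (whose final, now private, interface may differ), just to show what the transform pair does:

```python
import numpy as np
import pywt

# Fully separable n-dimensional decomposition and its inverse on a 3D array.
data = np.random.randn(8, 8, 8)
result = pywt.fswavedecn(data, "haar", levels=1)  # separable decomposition
reconstruction = pywt.fswaverecn(result)          # inverse transform
assert np.allclose(data, reconstruction)          # perfect reconstruction
```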